whitelist = {
# below are some jit functions
'wait', 'fork', 'parse_type_comment', 'import_ir_module',
- 'to_batch_graph', 'import_ir_module_from_buffer',
- 'register_batch_operator', 'merge_type_from_type_comment',
+ 'import_ir_module_from_buffer', 'merge_type_from_type_comment',
# below are symbols mistakenly bound to torch.*, but should
# go to torch.nn.functional.* instead
import random
from typing import List, Dict, Optional, Tuple
from torch.jit.frontend import NotSupportedError
-from torch.jit import BatchTensor
from torch import Tensor
from torch.jit.annotations import BroadcastingList2, BroadcastingList3
self.assertTrue(tested_blocks)
-class TestBatched(TestCase):
- # generate random examples and create an batchtensor with them
- def rand_batch(self, *dims):
- dims = [dim for dim in dims if dim != ()]
- xs = [torch.rand(1, *(random.randint(1, size) if b else size for b, size in dims[1:]),
- requires_grad=True) for i in range(dims[0])]
- xb = BatchTensor(xs, torch.tensor([b for b, d in dims[1:]]).byte())
- return xs, xb
-
- def test_create_batchtensor(self):
- # create from tensorlist
- xs, batch = self.rand_batch(4, (True, 3), (False, 2), (True, 5))
- self.assertEqual(xs, batch.examples())
- # create from data, mask, dims
- batch2 = BatchTensor(batch.get_data(), batch.get_mask(), batch.get_dims())
- self.assertEqual(xs, batch2.examples())
- # expand a tensor to a batchtensor given batch_size
- xs = torch.rand(3, 4, 5)
- batch3 = BatchTensor(xs, 2)
- xs = xs.unsqueeze(0)
- self.assertEqual([xs, xs], batch3.examples())
-
- def test_batch_elementwise_unary(self):
- @torch.jit.batch(batch_size=4)
- def tanh(a):
- return torch.tanh(a)
-
- xs, batch = self.rand_batch(4, (True, 3), (False, 2))
- res_batch = tanh(batch)
- res = [torch.tanh(xs[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_elementwise_binary(self):
- @torch.jit.batch(batch_size=4)
- def add(a, b):
- return a + b
-
- xs, batch = self.rand_batch(4, (True, 3), (False, 2))
- xs2, batch2 = xs, batch
- res_batch = add(batch, batch2)
- res = [torch.add(xs[j], xs2[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- # test broadcast
- xs, batch = self.rand_batch(4, (False, 3), (False, 2))
- b = torch.rand(3, 2)
- res_batch = add(batch, b)
- res = [torch.add(xs[j], b) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_mm(self):
- @torch.jit.batch(batch_size=4)
- def mm(a, b):
- return torch.mm(a, b)
-
- xs, batch = self.rand_batch(4, (True, 3), (False, 2))
- xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
- res_batch = mm(batch, batch2)
- res = [torch.mm(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- # test broadcast
- b = torch.rand(2, 4)
- res_batch = mm(batch, b)
- res = [torch.mm(xs[j].squeeze(0), b).unsqueeze(0) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_matmul(self):
- @torch.jit.batch(batch_size=4)
- def matmul(a, b):
- return torch.matmul(a, b)
-
- def matmul_test(xs, batch, xs2, batch2):
- ys = [torch.matmul(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
- ybs = matmul(batch, batch2)
- self.assertEqual(ys, ybs.examples())
-
- # 1 dimension * 1 dimension
- xs, batch = self.rand_batch(4, (False, 2))
- xs2, batch2 = self.rand_batch(4, (False, 2))
- matmul_test(xs, batch, xs2, batch2)
- # 1 dimension * 2 dimension
- xs, batch = self.rand_batch(4, (False, 2))
- xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
- matmul_test(xs, batch, xs2, batch2)
- # 2 dimension * 1 dimensions
- xs, batch = self.rand_batch(4, (True, 3), (False, 2))
- xs2, batch2 = self.rand_batch(4, (False, 2))
- matmul_test(xs, batch, xs2, batch2)
- # 2 dimension * 2 dimension
- xs, batch = self.rand_batch(4, (True, 3), (False, 2))
- xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
- matmul_test(xs, batch, xs2, batch2)
-
- def test_batch_select(self):
- @torch.jit.batch(batch_size=4)
- def select(x):
- return torch.select(x, 1, 0)
-
- xs, batch = self.rand_batch(4, (True, 3), (True, 2))
- res_batch = select(batch)
- res = [torch.select(xs[j], 1, 0) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- xs, batch = self.rand_batch(4, (False, 3), (True, 2))
- res_batch = select(batch)
- res = [torch.select(xs[j], 1, 0) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_index_select(self):
- @torch.jit.batch(batch_size=4)
- def index_select(x, ind):
- return x.index_select(1, ind)
-
- xs, batch = self.rand_batch(4, (False, 5), (True, 2))
- ind = [torch.randint(0, 4, (1,), dtype=torch.long) for i in range(4)]
- ind_batch = BatchTensor(ind, torch.tensor([]).byte())
- res_batch = index_select(batch, ind_batch)
- res = [torch.index_select(xs[j], 1, ind[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_where(self):
- @torch.jit.batch(batch_size=4)
- def where(c, a, b):
- return torch.where(c, a, b)
-
- xs, batch = self.rand_batch(4, (False, 3), (False, 2))
- xs2, batch2 = self.rand_batch(4, (False, 3), (False, 2))
-
- dims = [4, (False, 3), (False, 2)]
- xs_cond = [torch.rand(1, 3, 2).byte() for i in range(dims[0])]
- batch_cond = BatchTensor(xs_cond, torch.tensor([b for b, d in dims[1:]]))
-
- res_batch = where(batch_cond, batch, batch2)
- res = [torch.where(xs_cond[j], xs[j], xs2[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_argmax(self):
- @torch.jit.batch(batch_size=4)
- def argmax(a):
- return torch.argmax(a, 1)
-
- xs, batch = self.rand_batch(4, (True, 5), (True, 6))
- res_batch = argmax(batch)
- res = [torch.argmax(xs[j], 1) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- @torch.jit.batch(batch_size=4)
- def argmax(a):
- return torch.argmax(a, 1, False)
-
- res_batch = argmax(batch)
- res = [torch.argmax(xs[j], 1, False) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_topk(self):
- @torch.jit.batch(batch_size=4)
- def topk(a):
- return torch.topk(a, 3, 1)
-
- xs, batch = self.rand_batch(4, (False, 5), (True, 6))
-
- # along static dim
- res_batch = topk(batch)
- res = [torch.topk(xs[j], 3, 1)[0] for j in range(4)]
- res_idx = [torch.topk(xs[j], 3, 1)[1] for j in range(4)]
- self.assertEqual(res, res_batch[0].examples())
- self.assertEqual(res_idx, res_batch[1].examples())
-
- @torch.jit.batch(batch_size=4)
- def topk(a):
- return torch.topk(a, 1, 2)
-
- # along dynamic dim
- res_batch = topk(batch)
- res = [torch.topk(xs[j], 1, 2)[0] for j in range(4)]
- res_idx = [torch.topk(xs[j], 1, 2)[1] for j in range(4)]
- self.assertEqual(res, res_batch[0].examples())
- self.assertEqual(res_idx, res_batch[1].examples())
-
- def test_batch_softmax(self):
- @torch.jit.batch(batch_size=4)
- def softmax(a):
- return torch.softmax(a, 1)
-
- xs, batch = self.rand_batch(4, (False, 5), (True, 6))
-
- # along static dim
- res_batch = softmax(batch)
- res = [torch.softmax(xs[j], 1) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- @torch.jit.batch(batch_size=4)
- def softmax(a):
- return torch.softmax(a, 2)
-
- # along dynamic dim
- res_batch = softmax(batch)
- res = [torch.softmax(xs[j], 2) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_view(self):
- @torch.jit.batch(batch_size=4)
- def view(a):
- return a.view([4, -1, 3])
-
- xs, batch = self.rand_batch(4, (True, 5), (False, 3))
- res_batch = view(batch)
- res = [xs[j].view([1, -1, 3]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_cat(self):
- @torch.jit.batch(batch_size=4)
- def cat2(a, b):
- return torch.cat([a, b], 2)
-
- xs, batch = self.rand_batch(4, (True, 5), (False, 3))
- xs2, batch2 = xs, batch
- res_batch = cat2(batch, batch2)
- res = [torch.cat([xs[j], xs2[j]], 2) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_sum(self):
- @torch.jit.batch(batch_size=4)
- def batch_sum(a):
- return a.sum()
-
- xs, batch = self.rand_batch(4, (True, 5), (False, 3))
- res_batch = batch_sum(batch)
- res = [xs[j].sum().unsqueeze(0) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_if_else(self):
- def single_if(a, b):
- if bool(a > b):
- a = a + b
- else:
- a = a - b
- return a
-
- batch_if = torch.jit.batch(batch_size=4)(single_if)
-
- a, batch_a = self.rand_batch(4, ())
- b, batch_b = self.rand_batch(4, ())
- res_batch = batch_if(batch_a, batch_b)
- res = [single_if(a[j], b[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- script_if = torch.jit.script(single_if)
- torch.to_batch_graph(script_if.graph)
-
- def test_if_else_with_scalar(self):
- def single_if(a, b):
- if bool(a > 0.1):
- a = a + b
- else:
- a = a - b
- return a
-
- batch_if = torch.jit.batch(batch_size=4)(single_if)
-
- a, batch_a = self.rand_batch(4, ())
- b, batch_b = self.rand_batch(4, ())
- res_batch = batch_if(batch_a, batch_b)
- res = [single_if(a[j], b[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- script_if = torch.jit.script(single_if)
- torch.to_batch_graph(script_if.graph)
-
- def test_if_noelse(self):
- def single_if(a, b):
- if bool(a > b):
- a = a + b
- return a
-
- batch_if = torch.jit.batch(batch_size=4)(single_if)
-
- a, batch_a = self.rand_batch(4, ())
- b, batch_b = self.rand_batch(4, ())
- res_batch = batch_if(batch_a, batch_b)
- res = [single_if(a[j], b[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- script_if = torch.jit.script(single_if)
- torch.to_batch_graph(script_if.graph)
-
- def test_if_noelse_with_scalar(self):
- def single_if(a, b):
- if bool(a > 0.1):
- a = a + b
- return a
-
- batch_if = torch.jit.batch(batch_size=4)(single_if)
-
- a, batch_a = self.rand_batch(4, ())
- b, batch_b = self.rand_batch(4, ())
- res_batch = batch_if(batch_a, batch_b)
- res = [single_if(a[j], b[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- script_if = torch.jit.script(single_if)
- torch.to_batch_graph(script_if.graph)
-
- def test_while(self):
- def single_while(a, b):
- while bool(a > b):
- a = a - b
- return a
-
- batch_while = torch.jit.batch(batch_size=4)(single_while)
-
- a, batch_a = self.rand_batch(4, ())
- b = [torch.abs(torch.rand(1)) for i in range(4)]
- batch_b = BatchTensor(b, torch.tensor([]).byte())
- res_batch = batch_while(batch_a, batch_b)
- res = [single_while(a[j], b[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- script_while = torch.jit.script(single_while)
- torch.to_batch_graph(script_while.graph)
-
- def test_for(self):
- def single_for(x, y):
- for _ in range(10):
- x = x + y
- return x
-
- batch_for = torch.jit.batch(batch_size=4)(single_for)
-
- a, batch_a = self.rand_batch(4, ())
- b, batch_b = self.rand_batch(4, ())
- res_batch = batch_for(batch_a, batch_b)
- res = [single_for(a[j], b[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- script_for = torch.jit.script(single_for)
- torch.to_batch_graph(script_for.graph)
-
- def test_lstm(self):
- def LSTM(x_all, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c):
- for i in range(x_all.size(1)):
- x = x_all.select(1, i)
- i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
- f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
- o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
- # activations
- i_t = torch.sigmoid(i_t)
- f_t = torch.sigmoid(f_t)
- o_t = torch.sigmoid(o_t)
- # cell computations
- c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
- c_t = torch.tanh(c_t)
- c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
- h_t = torch.mul(o_t, torch.tanh(c_t))
- h = h_t
- c = c_t
- return h
-
- LSTM_batch = torch.jit.batch(batch_size=4)(LSTM)
-
- batch_size, input_size, hidden_size = 4, 3, 2
- xs, batch = self.rand_batch(batch_size, (True, 4), (False, input_size))
- hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
- cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))
-
- # input to hidden weights
- w_xi = torch.rand(input_size, hidden_size)
- w_xf = torch.rand(input_size, hidden_size)
- w_xo = torch.rand(input_size, hidden_size)
- w_xc = torch.rand(input_size, hidden_size)
- # hidden to hidden weights
- w_hi = torch.rand(hidden_size, hidden_size)
- w_hf = torch.rand(hidden_size, hidden_size)
- w_ho = torch.rand(hidden_size, hidden_size)
- w_hc = torch.rand(hidden_size, hidden_size)
- # bias terms
- b_i = torch.rand(hidden_size)
- b_f = torch.rand(hidden_size)
- b_o = torch.rand(hidden_size)
- b_c = torch.rand(hidden_size)
-
- ys = [LSTM(xs[j], hx[j], cx[j], w_xi, w_xf, w_xo, w_xc,
- w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c) for j in range(batch_size)]
- ybs = LSTM_batch(batch, h_batch, c_batch, w_xi, w_xf, w_xo, w_xc,
- w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
- self.assertEqual(ys, ybs.examples())
-
- @slowTest
- def test_greedy_search(self):
- def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
- b_i, b_f, b_o, b_c, w_hs, b_s, iter_num):
- iter_count = torch.zeros_like(iter_num)
- while bool(iter_count < iter_num):
- iter_count = iter_count + 1
- # LSTM Cell
- i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
- f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
- o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
- # activations
- i_t = torch.sigmoid(i_t)
- f_t = torch.sigmoid(f_t)
- o_t = torch.sigmoid(o_t)
- # cell computations
- c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
- c_t = torch.tanh(c_t)
- c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
- h_t = torch.mul(o_t, torch.tanh(c_t))
- h = h_t
- c = c_t
- # calculate feature with max probability
- s_t = torch.matmul(h_t, w_hs) + b_s
- p_t = torch.softmax(s_t, 1)
- i_t = torch.argmax(p_t, 1)
- x = embed.index_select(1, i_t).squeeze(1)
- return h
-
- greedy_batch = torch.jit.batch(batch_size=4)(greedy)
-
- batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
- xs, batch = self.rand_batch(batch_size, (False, input_size))
- hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
- cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))
- embed, embed_batch = self.rand_batch(batch_size, (False, vocab_size), (False, input_size))
- iter_num = [torch.randint(2, 5, (1,)) for i in range(batch_size)]
- iter_num_batch = BatchTensor(iter_num, torch.tensor([]).byte())
-
- # input to hidden weights
- w_xi = torch.rand(input_size, hidden_size)
- w_xf = torch.rand(input_size, hidden_size)
- w_xo = torch.rand(input_size, hidden_size)
- w_xc = torch.rand(input_size, hidden_size)
- # hidden to hidden weights
- w_hi = torch.rand(hidden_size, hidden_size)
- w_hf = torch.rand(hidden_size, hidden_size)
- w_ho = torch.rand(hidden_size, hidden_size)
- w_hc = torch.rand(hidden_size, hidden_size)
- # bias terms
- b_i = torch.rand(hidden_size)
- b_f = torch.rand(hidden_size)
- b_o = torch.rand(hidden_size)
- b_c = torch.rand(hidden_size)
- # hidden to vocab weights, bias
- w_hs = torch.rand(hidden_size, vocab_size)
- b_s = torch.rand(vocab_size)
-
- ys = [greedy(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc,
- w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j]) for j in range(batch_size)]
- ybs = greedy_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc,
- w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch)
- self.assertEqual(ys, ybs.examples())
-
- @slowTest
- def test_beam_search(self):
- def beam(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
- b_i, b_f, b_o, b_c, w_hs, b_s, iter_num, idx):
- k = 5
- vocab_size = embed.size(1)
- iter_count = torch.zeros_like(iter_num)
- max_len = idx.size(2)
- while bool(iter_count < iter_num):
- iter_count = iter_count + 1
- # LSTM Cell
- i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
- f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
- o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
- # activations
- i_t = torch.sigmoid(i_t)
- f_t = torch.sigmoid(f_t)
- o_t = torch.sigmoid(o_t)
- # cell computations
- c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
- c_t = torch.tanh(c_t)
- c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
- h_t = torch.mul(o_t, torch.tanh(c_t))
- h = h_t
- c = c_t
- # calculate features with max probability
- s_t = torch.matmul(h_t, w_hs) + b_s
- s_t = s_t.view([1, s_t.size(1) * s_t.size(2)])
- p_t = torch.softmax(s_t, 1)
- prob_t, idx_t = torch.topk(p_t, k, 1)
- if(int(idx_t.dim()) > 1):
- idx_t_tmp = idx_t.squeeze(0)
- else:
- idx_t_tmp = idx_t
- new_y = torch.fmod(idx_t_tmp, vocab_size)
- pre_y = idx_t_tmp / vocab_size
- x = embed.index_select(1, new_y)
- h = h_t.index_select(1, pre_y)
- c = c_t.index_select(1, pre_y)
- iter = int(iter_count[0])
- idx = torch.cat([idx.narrow(2, 0, iter).index_select(1, pre_y),
- torch.fmod(idx_t, vocab_size).unsqueeze(-1),
- idx.narrow(2, iter, max_len - iter)], 2)
- idx = idx.narrow(2, 0, max_len)
- return idx
-
- beam_batch = torch.jit.batch(batch_size=4)(beam)
-
- k = 5
- batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
- max_len = 5
- xs, batch = self.rand_batch(batch_size, (False, 1), (False, input_size))
- hx, h_batch = self.rand_batch(batch_size, (False, 1), (False, hidden_size))
- cx, c_batch = self.rand_batch(batch_size, (False, 1), (False, hidden_size))
- embed, embed_batch = self.rand_batch(batch_size, (False, vocab_size), (False, input_size))
- iter_num = [torch.randint(2, max_len + 1, (1,)) for i in range(batch_size)]
- iter_num_batch = BatchTensor(iter_num, torch.tensor([]).byte())
-
- # input to hidden weights
- w_xi = torch.rand(input_size, hidden_size)
- w_xf = torch.rand(input_size, hidden_size)
- w_xo = torch.rand(input_size, hidden_size)
- w_xc = torch.rand(input_size, hidden_size)
- # hidden to hidden weights
- w_hi = torch.rand(hidden_size, hidden_size)
- w_hf = torch.rand(hidden_size, hidden_size)
- w_ho = torch.rand(hidden_size, hidden_size)
- w_hc = torch.rand(hidden_size, hidden_size)
- # bias terms
- b_i = torch.rand(1, hidden_size)
- b_f = torch.rand(1, hidden_size)
- b_o = torch.rand(1, hidden_size)
- b_c = torch.rand(1, hidden_size)
- # hidden to vocab weights, bias
- w_hs = torch.rand(hidden_size, vocab_size)
- b_s = torch.rand(1, vocab_size)
-
- idx_batch = torch.jit.BatchTensor(torch.zeros([batch_size, k, max_len], dtype=torch.long),
- torch.zeros([batch_size, 1, max_len]).byte(),
- torch.tensor([0, 1]).byte())
- idx = [torch.zeros([1, k, max_len], dtype=torch.long) for _ in range(batch_size)]
-
- ys = [beam(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
- b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j], idx[j]).narrow(2, 0, int(iter_num[j]))
- for j in range(batch_size)]
- ybs = beam_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc,
- w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch, idx_batch)
- self.assertEqual(ys, ybs.examples())
-
-
def execWrapper(code, glob, loc):
if PY2:
exec(code) in glob, loc
${TORCH_SRC_DIR}/csrc/autograd/python_variable.cpp
${TORCH_SRC_DIR}/csrc/autograd/python_variable_indexing.cpp
${TORCH_SRC_DIR}/csrc/byte_order.cpp
- ${TORCH_SRC_DIR}/csrc/jit/batched/BatchTensor.cpp
${TORCH_SRC_DIR}/csrc/jit/init.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/onnx.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/onnx/fixup_onnx_loop.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/onnx/peephole.cpp
- ${TORCH_SRC_DIR}/csrc/jit/passes/to_batch.cpp
${TORCH_SRC_DIR}/csrc/jit/python_arg_flatten.cpp
${TORCH_SRC_DIR}/csrc/jit/python_interpreter.cpp
${TORCH_SRC_DIR}/csrc/jit/python_ir.cpp
+++ /dev/null
-#include <torch/csrc/jit/batched/BatchTensor.h>
-
-namespace torch {
-namespace jit {
-
-BatchTensor::BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims) {
- if (data.dim() != mask.dim() || mask.dim() != dims.size(0) + 1) {
- throw std::runtime_error(
- "malformed MaskedBatch with data.dim(): " + std::to_string(data.dim()) +
- ", mask.dim(): " + std::to_string(mask.dim()) +
- ", dims.size(0): " + std::to_string(dims.size(0)));
- }
- this->data = std::move(data);
- this->mask = std::move(mask);
- this->dims = std::move(dims);
-}
-
-BatchTensor::BatchTensor(const at::Tensor& data, int64_t batch_size) {
- dims = at::empty(data.dim(), data.options().dtype(at::kByte));
- dims.fill_(0);
- std::vector<int64_t> sizes(data.dim() + 1, -1);
- sizes[0] = batch_size;
- this->data = data.unsqueeze(0).expand(sizes);
- std::vector<int64_t> mask_sizes(data.dim() + 1, 1);
- mask_sizes[0] = batch_size;
- mask = at::empty(mask_sizes, data.options().dtype(at::kByte));
- mask.fill_(1);
-}
-
-BatchTensor::BatchTensor(
- const std::vector<at::Tensor>& datalist,
- at::Tensor dims) {
- auto bs = datalist.size();
- std::vector<int64_t> sizes(dims.size(0) + 1, 0),
- mask_sizes(dims.size(0) + 1, 0);
- sizes[0] = bs;
- mask_sizes[0] = bs;
- for (int64_t i = 1; i < dims.size(0) + 1; i++) {
- for (const auto& x : datalist) {
- sizes[i] = std::max(sizes[i], x.size(i));
- }
- mask_sizes[i] = *dims[i - 1].data<uint8_t>() ? sizes[i] : 1;
- }
- data = at::empty(sizes, datalist[0].options());
- data.fill_(0);
- mask = at::empty(mask_sizes, datalist[0].options().dtype(at::kByte));
- mask.fill_(0);
- for (std::size_t i = 0; i < datalist.size(); i++) {
- auto data_item = data.narrow(0, i, 1);
- auto mask_item = mask.narrow(0, i, 1);
- for (int64_t j = 0; j < dims.size(0); j++) {
- if (*dims[j].data<uint8_t>()) {
- data_item = data_item.narrow(j + 1, 0, datalist[i].size(j + 1));
- mask_item = mask_item.narrow(j + 1, 0, datalist[i].size(j + 1));
- }
- }
- data_item += datalist[i];
- mask_item.fill_(1);
- }
- this->dims = std::move(dims);
-}
-
-std::vector<at::Tensor> BatchTensor::examples() {
- std::vector<at::Tensor> result;
- // calculate number of valid entries in dth dimension of data
- auto mask_sum = [](at::Tensor data, int d) -> int64_t {
- data = data.sum(d, /*keepdim=*/true);
- while (data.dim() >= 1)
- data = data[0];
- return *data.data<int64_t>();
- };
- for (int64_t i = 0; i < data.size(0); i++) {
- auto data_tmp = data.narrow(0, i, 1);
- for (int64_t d = 0; d < dims.size(0); d++) {
- if (*dims[d].data<uint8_t>()) {
- data_tmp = data_tmp.narrow(d + 1, 0, mask_sum(mask[i], d));
- }
- }
- result.push_back(data_tmp);
- }
- return result;
-}
-
-void initBatchTensorBindings(PyObject* module) {
- auto m = py::handle(module).cast<py::module>();
- auto jit = m.def_submodule("_jit");
- py::class_<BatchTensor>(jit, "BatchTensor")
- .def(py::init<at::Tensor, at::Tensor, at::Tensor>())
- .def(py::init<at::Tensor, int64_t>())
- .def(py::init<std::vector<at::Tensor>, at::Tensor>())
- .def("examples", &BatchTensor::examples)
- .def("get_data", &BatchTensor::get_data)
- .def("get_mask", &BatchTensor::get_mask)
- .def("get_dims", &BatchTensor::get_dims);
-}
-
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-#include <ATen/ATen.h>
-#include <ATen/Tensor.h>
-#include <torch/csrc/jit/pybind.h>
-#include <iostream>
-#include <vector>
-
-namespace torch {
-namespace jit {
-struct BatchTensor {
- public:
- BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims);
- // expand a tensor to a batchtensor given batch_size
- BatchTensor(const at::Tensor& data, int64_t batch_size);
- BatchTensor(const std::vector<at::Tensor>& datalist, at::Tensor dims);
- const char* toString() const {
- return "BatchTensor";
- }
- at::IntArrayRef sizes() const {
- return data.sizes();
- }
- int64_t dim() const {
- return data.dim();
- }
- std::vector<at::Tensor> examples();
- at::Tensor get_data() {
- return data;
- }
- at::Tensor get_mask() {
- return mask;
- }
- at::Tensor get_dims() {
- return dims;
- }
-
- public:
- // data is a Tensor whose size is the batch size in the batch dimension,
- // the size of all examples in static dimensions,
- // and at least as large as the largest example in the batch in dynamic
- // dimensions.
- at::Tensor data;
- // mask is a Tensor whose size is the batch size in the batch dimension,
- // one in static dimensions,
- // and at least as large as the largest example in the batch in dynamic
- // dimensions. Each entry in the mask corresponds to one or more entries in
- // the data array (singleton, i.e., static, dimensions are broadcasted), with
- // a one in the mask denoting that the corresponding data entries represent
- // valid, meaningful data and a zero denoting that they do not.
- at::Tensor mask;
- // dims is a 1-dimensional tensor with a bool for each non-batch dimension,
- // representing whether that dimension is static (False) or dynamic (True).
- at::Tensor dims;
-};
-
-void initBatchTensorBindings(PyObject* module);
-} // namespace jit
-} // namespace torch
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/jit/argument_spec.h>
-#include <torch/csrc/jit/batched/BatchTensor.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/fuser/kernel_cache.h>
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
-#include <torch/csrc/jit/passes/to_batch.h>
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/python_arg_flatten.h>
tracer::initPythonTracerBindings(module);
script::initTreeViewBindings(module);
script::initJitScriptBindings(module);
- initBatchTensorBindings(module);
- initRegisterBatchOpsBindings(module);
}
} // namespace jit
+++ /dev/null
-#include <torch/csrc/jit/passes/to_batch.h>
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
-#include <torch/csrc/jit/script/compiler.h>
-
-namespace torch {
-namespace jit {
-
-std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>>
- ToBatch::batch_operator_table;
-
-std::shared_ptr<Graph> ToBatch::getBatchOperator(
- const std::string& name,
- int64_t num_inputs) {
- if (batch_operator_table.find(name) == batch_operator_table.end()) {
- throw std::runtime_error(
- "function " + name + " is not supported in batched tensor yet");
- }
- auto ops = batch_operator_table.at(name);
- if (num_inputs == -1) // default function
- return ops[0];
- for (auto op : ops) {
- if (size_t(num_inputs) == op->inputs().size())
- return op;
- }
- throw std::runtime_error(
- "function " + name + " with " + std::to_string(num_inputs) +
- " inputs is not supported in batched tensor yet");
-}
-
-std::vector<Value*> inlineUnpackedCallTo(
- Graph& g,
- Graph& callee,
- ArrayRef<Value*> inputs) {
- return inlineCallTo(g, callee, inputs, /*unpack_outputs=*/true);
-}
-
-// replace aten operator node with BatchTensor operator graph
-void ToBatch::visitAten(Node* n, Block* block, Block* res_block) {
- auto res_graph = res_block->owningGraph();
- auto func_name = std::string(n->kind().toUnqualString());
- std::vector<Value*> new_inputs;
- for (Value* input : n->inputs()) {
- if (rn_env.find(input) == rn_env.end()) { // non-tensor input
- auto new_input = batch_map.at(input);
- new_inputs.insert(new_inputs.end(), new_input.begin(), new_input.end());
- } else { // batched tensor input
- new_inputs.push_back(rn_env.at(input));
- }
- }
-
- // transform scalar to tensor before pass to batch operator script
- for (auto& input : new_inputs) {
- if (input->type() == IntType::get() || input->type() == FloatType::get() ||
- input->type() == BoolType::get()) {
- auto to_tensor_node = res_graph->createNumToTensor(input);
- res_graph->insertNode(to_tensor_node);
- input = to_tensor_node->output();
- }
- }
-
- auto batch_graph = getBatchOperator(func_name, new_inputs.size());
- auto outputs =
- inlineUnpackedCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);
-
- // Assume all outputs from inlined operator implementation are in the triple
- // form batched tensor or just a single non-tensor.
- if (outputs.size() == 1) {
- // if previous output is scalar, transform new output back to scalar from
- // dynamic
- TypePtr orig_type = n->outputs()[0]->type();
- if (!orig_type->isSubtypeOf(outputs[0]->type())) {
- Symbol op;
- if (orig_type == IntType::get()) {
- op = prim::Int;
- } else if (orig_type == FloatType::get()) {
- op = prim::Float;
- } else if (orig_type == BoolType::get()) {
- op = prim::Bool;
- } else {
- throw std::runtime_error(
- "NYI: scalar types other than int, float, and bool are not supported yet");
- }
- rn_env[n->outputs()[0]] = res_graph->insert(op, {outputs[0]});
- } else {
- rn_env[n->outputs()[0]] = outputs[0];
- }
- } else {
- for (size_t i = 0; i < n->outputs().size(); i++) {
- auto output = n->outputs()[i];
- batch_map[output] = std::vector<Value*>(
- outputs.begin() + i * EXP_BTENSOR_SIZE,
- outputs.begin() + i * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE);
- }
- }
-}
-
-// clone prim::Constant to new graph
-// batching transformation is applied to the output of prim::NumToTensor.
-// If there is a prim::NumToTensor following prim::Constant, it will be finally
-// transformed to BatchTensor.
-void ToBatch::visitConstant(Node* n, Block* block, Block* res_block) {
- auto res_graph = res_block->owningGraph();
- auto* r_node = res_graph->createClone(n, rn_fn);
- res_block->appendNode(r_node);
- rn_env[n->output()] = r_node->output();
-}
-
-// change return tensor to expanded batched tensor, eg: {data, mask, dims}
-void ToBatch::visitNumToTensor(Node* n, Block* block, Block* res_block) {
- auto res_graph = res_block->owningGraph();
- auto* r_node = res_graph->createClone(n, rn_fn);
- res_block->appendNode(r_node);
- auto outputs = inlineUnpackedCallTo(
- *res_block->owningGraph(),
- *getBatchOperator("batch_from_scalar_tensor"),
- r_node->outputs());
- batch_map[n->output()] = outputs;
-}
-
-// clone prim::TensorToNum to new graph
-void ToBatch::visitTensorToNum(Node* n, Block* block, Block* res_block) {
- auto res_graph = res_block->owningGraph();
- if (rn_env.find(n->input()) == rn_env.end()) {
- rn_env[n->input()] = batch_map.at(n->input())[0];
- }
- auto* r_node = res_graph->createClone(n, rn_fn);
- res_block->appendNode(r_node);
- rn_env[n->output()] = r_node->output();
- batch_map[n->output()] = batch_map.at(n->input());
-}
-
-// clone prim::ListConstruct to new graph
-void ToBatch::visitListConstruct(Node* n, Block* block, Block* res_block) {
- auto res_graph = res_block->owningGraph();
- if (n->inputs()[0]->type() ==
- TensorType::get()) { // TensorList: expand directly
- std::vector<Value*> inputs;
- for (Value* input : n->inputs()) {
- auto res = batch_map.at(input);
- inputs.insert(inputs.end(), res.begin(), res.end());
- }
- batch_map[n->output()] = inputs;
- } else { // ScalarList: transform to tensor, then transform back
- for (Value* input : n->inputs()) {
- if (rn_env.find(input) == rn_env.end()) {
- rn_env[input] = batch_map.at(input)[0];
- }
- }
- auto* r_node = res_graph->createClone(n, rn_fn);
- res_block->appendNode(r_node);
- // transform int[] to tensor
- auto to_tensor_node =
- res_graph->create(Symbol::fromQualString("aten::_list_to_tensor"));
- to_tensor_node->addInput(r_node->output());
- res_block->appendNode(to_tensor_node);
- rn_env[n->output()] = to_tensor_node->output();
- }
-}
-
-// clang-format off
-// prim::If transformation:
-// elif is not supported
-//
-// transformation example:
-// @torch.jit.batch(batch_size=4)
-// def batch_if(a, b):
-// if a > b:
-// a += b
-// else:
-// a -= b
-// return a
-//
-// original graph:
-// graph(%a.1 : Dynamic
-// %b : Dynamic) {
-// %2 : Dynamic = aten::gt(%a.1, %b)
-// %a : Dynamic = prim::If(%2)
-// block0() {
-// %a.2 : Dynamic = aten::add[alpha={1}](%a.1, %b)
-// -> (%a.2)
-// }
-// block1() {
-// %a.3 : Dynamic = aten::sub[alpha={1}](%a.1, %b)
-// -> (%a.3)
-// }
-// return (%a);
-// }
-//
-// transformed graph:
-// graph(%a.1_data : Dynamic
-// %a.1_mask : Dynamic
-// %a.1_dims : Dynamic
-// %b_data : Dynamic
-// %b_mask : Dynamic
-// %b_dims : Dynamic) {
-// %6 : Dynamic = aten::gt(%a.1_data, %b_data) // calculate condition
-// %7 : Dynamic = aten::mul(%a.1_mask, %b_mask)
-// %8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
-// %9 : int = prim::TensorToNum(%6)
-// %10 : Long() = prim::Constant[value={1}]() // if_block
-// %alpha.1 : float = prim::TensorToNum(%10)
-// %data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha.1)
-// %mask.1 : Dynamic = aten::mul(%a.1_mask, %b_mask)
-// %dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
-// %15 : Long() = prim::Constant[value={1}]() // else_block
-// %alpha : float = prim::TensorToNum(%15)
-// %data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha)
-// %mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
-// %dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
-// %20 : Dynamic = aten::type_as(%7, %6) // combine two outputs (batch_where)
-// %cond_mask.1 : Dynamic = aten::mul(%6, %20)
-// %22 : int = aten::dim(%cond_mask.1)
-// %23 : int = prim::Constant[value=1]()
-// %24 : int = aten::eq(%22, %23)
-// %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%24)
-// block0() {
-// %28 : int = aten::dim(%data.1)
-// %29 : int = prim::Constant[value=1]()
-// %30 : int = aten::sub(%28, %29)
-// %31 : int = prim::Constant[value=1]()
-// %data.3 : Dynamic = prim::Loop(%30, %31, %cond_mask.1)
-// block0(%_ : int, %34 : Dynamic) {
-// %35 : int = prim::Constant[value=1]()
-// %36 : int = aten::neg(%35)
-// %data.2 : Dynamic = aten::unsqueeze(%34, %36)
-// %38 : int = prim::Constant[value=1]()
-// -> (%38, %data.2)
-// }
-// %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
-// %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1)
-// -> (%cond_data.1, %cond_mask.2, %data.3)
-// }
-// block1() {
-// -> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
-// }
-// %res_data : Dynamic = aten::where(%cond_data, %data.1, %data.4)
-// %res_mask : Dynamic = aten::where(%cond_mask, %mask.1, %mask)
-// %res_dims : Dynamic = aten::__or__(%dims.1, %dims)
-// return (%res_data, %res_mask, %res_dims);
-// }
-// clang-format on
-void ToBatch::visitIf(Node* n, Block* block, Block* res_block) {
- toBatch(n->blocks()[0], res_block);
- toBatch(n->blocks()[1], res_block);
-
- // combine results from two if paths
- for (size_t i = 0; i < n->outputs().size(); i++) {
- std::vector<Value*> inputs;
- if (batch_map.find(n->input()) == batch_map.end()) { // cond is scalar
- inputs.push_back(rn_env.at(n->input()));
- } else { // cond is tensor
- auto cond = batch_map.at(n->input());
- inputs.insert(inputs.end(), cond.begin(), cond.end());
- }
- auto if_output = batch_map.at(n->blocks()[0]->outputs()[i]);
- inputs.insert(inputs.end(), if_output.begin(), if_output.end());
- auto else_output = batch_map.at(n->blocks()[1]->outputs()[i]);
- inputs.insert(inputs.end(), else_output.begin(), else_output.end());
- auto outputs = inlineUnpackedCallTo(
- *res_block->owningGraph(),
- *getBatchOperator("where", inputs.size()),
- inputs);
- batch_map[n->outputs()[i]] = outputs;
- }
-}
-
-// clang-format off
-// prim::Loop transformation:
-//
-// transformation example:
-// @torch.jit.batch(batch_size=4)
-// def batch_while(a, b):
-// while a > b:
-// a -= b
-// return a
-//
-// original graph:
-// graph(%a.1 : Dynamic
-// %b : Dynamic) {
-// %2 : int = prim::Constant[value={2147483647}]()
-// %3 : Dynamic = aten::gt(%a.1, %b)
-// %a : Dynamic = prim::Loop(%2, %3, %a.1)
-// block0(%4 : Dynamic, %5 : Dynamic) {
-// %a.2 : Dynamic = aten::sub[alpha={1}](%5, %b)
-// %9 : Dynamic = aten::gt(%a.2, %b)
-// -> (%9, %a.2)
-// }
-// return (%a);
-// }
-//
-// transformed graph:
-// graph(%a.1_data : Dynamic
-// %a.1_mask : Dynamic
-// %a.1_dims : Dynamic
-// %b_data : Dynamic
-// %b_mask : Dynamic
-// %b_dims : Dynamic) {
-// %6 : int = prim::Constant[value=2147483647]()
-// %7 : Dynamic = aten::gt(%a.1_data, %b_data)
-// %8 : Dynamic = aten::mul(%a.1_mask, %b_mask)
-// %9 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
-// %10 : int = prim::TensorToNum(%7)
-// %11 : Dynamic = aten::mul(%7, %8)
-// %12 : Dynamic = aten::sum(%11)
-// %13 : Dynamic = aten::gt[other={0}](%12) // cond_any
-// %14 : int = prim::TensorToNum(%13)
-// %62 : Dynamic, %63 : Dynamic, %64 : Dynamic, %a : Dynamic, %60 : Dynamic, %61 : Dynamic = prim::Loop(%6, %14, %7, %8, %9, %a.1_data, %a.1_mask, %a.1_dims)
-// block0(%loop_num : int, %cond_data.2 : Dynamic, %cond_mask.3 : Dynamic, %cond_dims : Dynamic, %6_data : Dynamic, %6_mask : Dynamic, %6_dims : Dynamic) {
-// %23 : Long() = prim::Constant[value={1}]()
-// %alpha : float = prim::TensorToNum(%23)
-// %data.1 : Dynamic = aten::sub(%6_data, %b_data, %alpha)
-// %mask : Dynamic = aten::mul(%6_mask, %b_mask)
-// %dims : Dynamic = aten::__or__(%6_dims, %b_dims)
-// %28 : Dynamic = aten::gt(%data.1, %b_data)
-// %29 : Dynamic = aten::mul(%mask, %b_mask)
-// %30 : Dynamic = aten::__or__(%dims, %b_dims)
-// %31 : int = prim::TensorToNum(%28)
-// %32 : Dynamic = aten::type_as(%cond_mask.3, %cond_data.2) // update outputs (batch_where)
-// %cond_mask.1 : Dynamic = aten::mul(%cond_data.2, %32)
-// %34 : int = aten::dim(%cond_mask.1)
-// %35 : int = prim::Constant[value=1]()
-// %36 : int = aten::eq(%34, %35)
-// %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%36)
-// block0() {
-// %40 : int = aten::dim(%data.1)
-// %41 : int = prim::Constant[value=1]()
-// %42 : int = aten::sub(%40, %41)
-// %43 : int = prim::Constant[value=1]()
-// %data.3 : Dynamic = prim::Loop(%42, %43, %cond_mask.1)
-// block0(%_ : int, %46 : Dynamic) {
-// %47 : int = prim::Constant[value=1]()
-// %48 : int = aten::neg(%47)
-// %data.2 : Dynamic = aten::unsqueeze(%46, %48)
-// %50 : int = prim::Constant[value=1]()
-// -> (%50, %data.2)
-// }
-// %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
-// %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)
-// -> (%cond_data.1, %cond_mask.2, %data.3)
-// }
-// block1() {
-// -> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
-// }
-// %res_data : Dynamic = aten::where(%cond_data, %data.1, %6_data)
-// %res_mask : Dynamic = aten::where(%cond_mask, %mask, %6_mask)
-// %res_dims : Dynamic = aten::__or__(%dims, %6_dims)
-// %56 : Dynamic = aten::mul(%28, %29)
-// %57 : Dynamic = aten::sum(%56)
-// %58 : Dynamic = aten::gt[other={0}](%57)
-// %59 : int = prim::TensorToNum(%58)
-// -> (%59, %28, %29, %30, %res_data, %res_mask, %res_dims)
-// }
-// return (%a, %60, %61);
-// }
-// clang-format on
-void ToBatch::visitLoop(Node* n, Block* block, Block* res_block) {
- auto res_graph = res_block->owningGraph();
- // bool cond_is_tensor indicates whether cond is tensor
- // cond_is_tensor = false, eg: for loop, n->inputs()[1] = byte()
- // cond_is_tensor = true, eg: in some while loop, cond is a batched tensor,
- // we need to add expanded cond to the inputs of
- // loop node and block, and compute cond_any as
- // cond for while loop
- bool cond_is_tensor = (batch_map.find(n->inputs()[1]) != batch_map.end());
-
- // create prim::Loop node for res_block
-
- // type of cond in loop should be int type
- if (rn_env.at(n->inputs()[0])->type() != IntType::get()) {
- rn_env[n->inputs()[0]] =
- res_graph->insert(prim::Int, {rn_env.at(n->inputs()[0])});
- }
- if (cond_is_tensor) {
- auto cond = batch_map.at(n->inputs()[1]);
- auto cond_any = inlineUnpackedCallTo(
- *res_block->owningGraph(), *getBatchOperator("any"), cond);
- rn_env[n->inputs()[1]] = res_graph->insert(prim::Bool, {cond_any[0]});
- }
- for (size_t i = 2; i < n->inputs().size(); i++) {
- auto input = n->inputs()[i];
- rn_env[input] = batch_map.at(input)[0];
- }
- auto* r_node = res_graph->createClone(n, rn_fn, /*copy_blocks=*/false);
-
- // change inputs of prim::Loop
- if (cond_is_tensor) {
- for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
- auto cond = batch_map.at(n->inputs()[1]);
- r_node->insertInput(i + 2, cond[i]);
- }
- }
- for (size_t i = 2; i < n->inputs().size(); i++) {
- for (size_t j = 1; j < EXP_BTENSOR_SIZE; j++) {
- r_node->insertInput(
- (i - 2) * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE * cond_is_tensor + 2 +
- j,
- batch_map.at(n->inputs()[i])[j]);
- }
- }
- res_block->appendNode(r_node);
-
- // create block for Loop node in res_block
- // if cond is tensor: first 4 inputs of block: cond_any, cond_data,
- // cond_mask, cond_dims
- // if cond is not tensor: first 1 input of block: cond
- auto loop_block = r_node->addBlock();
-
- // add inputs
- loop_block->addInput("loop_num");
- loop_block->inputs()[0]->setType(IntType::get());
- rn_env[n->blocks()[0]->inputs()[0]] = loop_block->inputs()[0];
- if (cond_is_tensor) {
- for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
- loop_block->addInput("cond_" + EXP_BTENSOR_NAME[i]);
- }
- }
- for (size_t i = 1; i < n->blocks()[0]->inputs().size(); i++) {
- auto input = n->blocks()[0]->inputs()[i];
- auto name = input->uniqueName();
- for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
- loop_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]);
- }
- batch_map[input] =
- std::vector<Value*>(loop_block->inputs()
- .slice(
- (i - 1) * EXP_BTENSOR_SIZE + 1 +
- EXP_BTENSOR_SIZE * cond_is_tensor,
- EXP_BTENSOR_SIZE)
- .vec());
- }
-
- toBatch(n->blocks()[0], loop_block);
-
- WithInsertPoint guard(loop_block);
-
- // use where operator to update variables and add to outputs
- for (size_t i = 0; i < n->outputs().size(); i++) {
- std::vector<Value*> inputs, outputs;
- if (cond_is_tensor) {
- for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
- inputs.push_back(loop_block->inputs()[j + 1]);
- }
- auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
- inputs.insert(inputs.end(), data.begin(), data.end());
- for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
- inputs.push_back(
- loop_block
- ->inputs()[i * EXP_BTENSOR_SIZE + j + EXP_BTENSOR_SIZE + 1]);
- }
- outputs = inlineUnpackedCallTo(
- *res_block->owningGraph(), *getBatchOperator("where"), inputs);
- } else {
- for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
- inputs.push_back(loop_block->inputs()[i * EXP_BTENSOR_SIZE + j + 1]);
- }
- auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
- inputs.insert(inputs.end(), data.begin(), data.end());
- outputs = inlineUnpackedCallTo(
- *res_block->owningGraph(), *getBatchOperator("update"), inputs);
- }
- batch_map[n->outputs()[i]] = outputs;
- for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
- loop_block->registerOutput(outputs[j]);
- }
- }
-
- // update loop conditions
- if (cond_is_tensor) {
- auto cond = batch_map.at(n->blocks()[0]->outputs()[0]);
- auto cond_any = inlineUnpackedCallTo(
- *res_block->owningGraph(), *getBatchOperator("any"), cond);
- auto to_bool_output = res_graph->insert(prim::Bool, {cond_any[0]});
- loop_block->insertOutput(0, to_bool_output);
- for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
- loop_block->insertOutput(i + 1, cond[i]);
- }
- } else {
- auto cond = rn_env.at(n->blocks()[0]->outputs()[0]);
- loop_block->insertOutput(0, cond);
- }
-
- // change outputs of prim::Loop
- auto size = r_node->outputs().size();
- for (size_t i = 0; i < size; i++) {
- for (size_t j = 1; j < EXP_BTENSOR_SIZE; j++) {
- r_node->insertOutput(i * EXP_BTENSOR_SIZE + j);
- }
- batch_map[n->outputs()[i]] =
- r_node->outputs().slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE).vec();
- }
- // add cond to outputs of loop node
- if (cond_is_tensor) {
- for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
- r_node->insertOutput(i);
- }
- }
-}
-
-void ToBatch::toBatch(Block* block, Block* res_block) {
- WithInsertPoint guard(res_block);
-
- // change inputs of block-expand tensor to batchtensor eg: (data, mask, dims)
- // eg: a -> a_data, a_mask, a_dims for block in prim::Loop, register inputs
- // separately to deal with cond
- if (!block->owningNode() || block->owningNode()->kind() != prim::Loop) {
- auto size = block->inputs().size();
- for (size_t i = 0; i < size; i++) {
- auto input = block->inputs()[i];
- auto name = input->uniqueName();
- for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
- res_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]);
- }
- batch_map[input] =
- std::vector<Value*>(res_block->inputs()
- .slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE)
- .vec());
- }
- }
-
- for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
- auto n = *it;
- if (n->kind().is_aten()) {
- visitAten(n, block, res_block);
- } else if (n->kind().is_prim()) {
- switch (n->kind()) {
- case prim::Constant:
- visitConstant(n, block, res_block);
- break;
- case prim::NumToTensor:
- visitNumToTensor(n, block, res_block);
- break;
- case prim::Bool:
- case prim::Float:
- case prim::Int:
- visitTensorToNum(n, block, res_block);
- break;
- case prim::ListConstruct:
- visitListConstruct(n, block, res_block);
- break;
- case prim::If:
- visitIf(n, block, res_block);
- break;
- case prim::Loop:
- visitLoop(n, block, res_block);
- break;
- default:
- throw std::runtime_error(
- "NYI: node of prim kind other than [Constant, NumToTensor, TensorToNum, If, Loop] is not supported yet");
- }
- } else {
- throw std::runtime_error(
- "NYI: node that is not aten or prim kind is not supported yet");
- }
- }
- // change outputs of block - expand tensor to batchtensor(data, mask, dims)
- // for block in prim::Loop, register outputs separately to deal with cond and
- // cond_any
- //
- // for block in prim::If, register outputs separately by combining
- // outputs from two paths and return
- if (!block->owningNode() ||
- (block->owningNode()->kind() != prim::Loop &&
- block->owningNode()->kind() != prim::If)) {
- for (Value* output : block->outputs()) {
- auto r_output = batch_map.at(output);
- for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
- res_block->registerOutput(r_output[i]);
- }
- }
- }
-}
-
-std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph) {
- // lower the tuple before the pass
- if (graph->outputs().at(0)->type()->kind() == TupleType::Kind) {
- graph = graph->copy();
- auto outs = createTupleUnpack(graph->outputs().at(0));
- graph->eraseOutput(0);
- for (auto o : outs)
- graph->registerOutput(o);
- EliminateDeadCode(graph->block());
- }
- std::shared_ptr<Graph> res_graph = std::make_shared<Graph>();
- ToBatch to_batch;
- to_batch.toBatch(graph->block(), res_graph->block());
-
- // methods should only have a single output, so we pack everything into a
- // tuple
- auto tup =
- res_graph->insertNode(res_graph->createTuple(res_graph->outputs()));
- while (res_graph->outputs().size() > 0)
- res_graph->eraseOutput(res_graph->outputs().size() - 1);
- res_graph->registerOutput(tup->output());
- EliminateDeadCode(res_graph->block());
-
- return res_graph;
-}
-
-void initRegisterBatchOpsBindings(PyObject* module) {
- auto m = py::handle(module).cast<py::module>();
- m.def("to_batch_graph", to_batch_graph);
- m.def(
- "register_batch_operator",
- [](std::string name, std::shared_ptr<Graph> graph) {
- ToBatch::batch_operator_table[name].push_back(graph);
- });
-}
-
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/pybind.h>
-
-#include <ATen/ATen.h>
-
-namespace torch {
-namespace jit {
-
-class ToBatch {
- private:
- // number of tensors to represent a expanded BatchTensor. {data, mask, dims}
- // for now.
- const size_t EXP_BTENSOR_SIZE = 3;
- const std::vector<std::string> EXP_BTENSOR_NAME = {"data", "mask", "dims"};
- // mapping from tensor in original graph to {data, mask, dims} in new graph
- std::unordered_map<Value*, std::vector<Value*>> batch_map;
- // mapping from input in original graph to new input in new graph - used in
- // createClone
- std::unordered_map<Value*, Value*> rn_env;
- std::function<Value*(Value*)> rn_fn = [this](Value* v) {
- return rn_env.at(v);
- };
-
- private:
- std::shared_ptr<Graph> getBatchOperator(
- const std::string& name,
- int64_t input_num = -1);
- void visitAten(Node* n, Block* block, Block* res_block);
- void visitConstant(Node* n, Block* block, Block* res_block);
- void visitNumToTensor(Node* n, Block* block, Block* res_block);
- void visitTensorToNum(Node* n, Block* block, Block* res_block);
- void visitListConstruct(Node* n, Block* block, Block* res_block);
- void visitIf(Node* n, Block* block, Block* res_block);
- void visitLoop(Node* n, Block* block, Block* res_block);
-
- public:
- static std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>>
- batch_operator_table;
- TORCH_API void toBatch(Block* block, Block* res_block);
-};
-
-TORCH_API std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph);
-TORCH_API void initRegisterBatchOpsBindings(PyObject* module);
-} // namespace jit
-} // namespace torch
#include <torch/csrc/jit/hooks_for_testing.h>
#include <torch/csrc/jit/import_source.h>
#include <torch/csrc/jit/passes/python_print.h>
-#include <torch/csrc/jit/passes/to_batch.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/python_tracer.h>
#include <torch/csrc/jit/script/parser.h>
_unflatten = torch._C._jit_unflatten
_jit_script_compile = torch._C._jit_script_compile
_jit_script_class_compile = torch._C._jit_script_class_compile
-BatchTensor = torch._C._jit.BatchTensor
Future = torch._C.Future
_fork = torch._C.fork
return cls in _jit_internal.weak_types
-def batch(batch_size=1, optimize=True, _frames_up=0):
- def decorator(fn):
- if not _enabled:
- return fn
- import torch.jit.batchop
- mod = script(fn, optimize, _frames_up)
- res_graph = torch.to_batch_graph(mod.graph)
- res_mod = ScriptModule()
- res_mod._create_method_from_graph('forward', res_graph)
-
- def wrapper(*args):
- new_args = []
- for arg in args:
- if isinstance(arg, torch.Tensor):
- arg = BatchTensor(arg, batch_size)
- if isinstance(arg, BatchTensor):
- new_args.extend([arg.get_data(), arg.get_mask(), arg.get_dims()])
- else:
- new_args.append(arg)
- res = res_mod(*new_args)
- assert len(res) % 3 == 0
- if len(res) % 3 != 0:
- raise "non-batched-tensor output is not supported yet"
- result = [BatchTensor(*res[i * 3: i * 3 + 3]) for i in range(len(res) // 3)]
- if len(result) == 1:
- return result[0]
- return result
- wrapper.__doc__ = fn.__doc__
- return wrapper
- return decorator
-
-
# These OrderedDictWrapper classes replace the actual OrderedDicts in
# module with versions that get/set properties inside of script::Module.
# This allows us to reuse most of nn.Module while still storing the
+++ /dev/null
-import torch
-from torch.jit import BatchTensor
-
-
-# TODO: there are some commented raise statements
-# when we support rasie exception in script, we want to check them
-@torch.jit.script
-def batch_tanh(data, mask, dims):
- data = torch.tanh(data)
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_sigmoid(data, mask, dims):
- data = torch.sigmoid(data)
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_relu(data, mask, dims):
- data = torch.relu(data)
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_neg(data, mask, dims):
- data = torch.neg(data)
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_neg_scalar(data):
- return torch.neg(data)
-
-
-@torch.jit.script
-def batch_add(data1, mask1, dims1, data2, mask2, dims2, alpha_):
- alpha = float(alpha_)
- data = torch.add(data1, data2, alpha=alpha)
- mask = mask1 * mask2
- dims = dims1.__or__(dims2)
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_add_scalar(data, mask, dims, other, alpha_):
- alpha = float(alpha_)
- data = torch.add(data, other.type_as(data), alpha=alpha)
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_sub(data1, mask1, dims1, data2, mask2, dims2, alpha_):
- alpha = float(alpha_)
- data = torch.sub(data1, data2, alpha=alpha)
- mask = mask1 * mask2
- dims = dims1.__or__(dims2)
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_sub_scalar(data1, data2):
- return data1 - data2
-
-
-@torch.jit.script
-def batch_mul(data1, mask1, dims1, data2, mask2, dims2):
- data = torch.mul(data1, data2)
- mask = mask1 * mask2
- dims = dims1.__or__(dims2)
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_mul_scalar(data1, data2):
- return data1 * data2
-
-
-@torch.jit.script
-def batch_div(data, mask, dims, other): # div(batchtensor, scalar)
- data = torch.div(data, other)
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_mm(data1, mask1, dims1, data2, mask2, dims2):
- data1 = data1 * mask1.type_as(data1)
- data2 = data2 * mask2.type_as(data2)
- data = torch.bmm(data1, data2)
- mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1))
- dims = torch.cat((dims1[:1], dims2[1:dims2.size(0)]))
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_matmul(data1, mask1, dims1, data2, mask2, dims2):
- d1 = data1.dim() - 1
- d2 = data2.dim() - 1
- data1 = data1 * mask1.type_as(data1)
- data2 = data2 * mask2.type_as(data2)
- if d1 == 1:
- data1 = data1.unsqueeze(-2)
- if d2 == 1:
- data2 = data2.unsqueeze(-1)
- data = torch.bmm(data1, data2)
- mask = mask1
- dims = dims1
- if d1 == 1 and d2 == 1:
- # if (batch1.dims[0] or batch2.dims[0]) and not batch1.mask.eq(batch2.mask).all():
- # raise ValueError("cannot contract non-matching dimensions")
- data = data.squeeze(-1).squeeze(-1)
- mask = mask1.narrow(1, 0, 1).squeeze(-1)
- dims = dims1[:0] # empty tensor
- if d1 == 2 and d2 == 1:
- # if (batch1.dims[1] or batch2.dims[0]) and not batch1.mask[:, 0].eq(batch2.mask).all():
- # raise ValueError("cannot contract non-matching dimensions")
- data = data.squeeze(-1)
- mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1).unsqueeze(-1)).squeeze(-1)
- dims = dims1[:1]
- elif d1 == 1 and d2 == 2:
- # if (batch1.dims[0] or batch2.dims[0]) and not batch1.mask.eq(batch2.mask[:, :, 0]).all():
- # raise ValueError("cannot contract non-matching dimensions")
- data = data.squeeze(-2)
- mask = torch.bmm(mask1.narrow(1, 0, 1).unsqueeze(-2), mask2.narrow(1, 0, 1)).squeeze(-2)
- dims = dims2[1:dims2.size(0)]
- elif d1 == 2 and d2 == 2:
- # if (batch1.dims[1] or batch2.dims[0]) and not batch1.mask[:, 0].eq(batch2.mask[:, :, 0]).all():
- # raise ValueError("cannot contract non-matching dimensions")
- mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1))
- dims = torch.cat((dims1[:1], dims2[1:dims2.size(0)]))
- # else:
- # raise NotImplementedError("matmul not implemented with batches of 3+D tensors")
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_select(data, mask, dims, dim_, index_):
- dim = int(dim_)
- index = int(index_)
- # if dim == 0:
- # raise ValueError("Cannot select 0 dim in BatchTensor")
- data = data.select(dim, index)
- if bool(dims[dim - 1]):
- mask = mask.select(dim, index)
- else:
- mask = mask.select(dim, 0)
- dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_fmod(data, mask, dims, other_):
- other = int(other_)
- data = torch.fmod(data, other)
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_zeros_like(data, mask, dims):
- res_data = torch.zeros_like(data)
- return res_data, mask, dims
-
-
-@torch.jit.script
-def batch_index_select(data, mask, dims, dim_, index_data, index_mask, index_dims):
- dim = int(dim_)
- # if dim == 0:
- # raise ValueError("Cannot index_select along 0 dim in BatchTensor")
- batch_size = data.size(0) # TODO maybe index_mask will be used at some point
- res_data = torch.zeros([0])
- res_mask = torch.zeros([0])
- for i in range(batch_size):
- d = data[i].index_select(dim - 1, index_data[i]).unsqueeze(0)
- if bool(dims[dim - 1]):
- m = mask[i].index_select(dim - 1, index_data[i]).unsqueeze(0)
- else:
- m = mask[i].unsqueeze(0)
- if i == 0:
- res_data = d
- res_mask = m
- else:
- res_data = torch.cat((res_data, d), 0)
- res_mask = torch.cat((res_mask, m), 0)
- return res_data, res_mask, dims
-
-
-@torch.jit.script
-def batch_view_as(data, mask, dims, data1, mask1, dims1):
- # if data.size(0) != data1.size(0):
- # raise ValueError("In view_as, tensor and target tensor should have the same batch_size")
- # if not torch.equal(dims, dims1):
- # raise ValueError("In batched view_as, dims and target dims should be the same")
- data = data.view_as(data1)
- mask = mask.view_as(mask1)
- dims = dims1
- return data, mask, dims
-
-
-# assume data, data1, data2 have same size
-@torch.jit.script
-def batch_where(data, mask, dims, data1, mask1, dims1, data2, mask2, dims2):
- data = data * mask.type_as(data)
- cond_data = data
- cond_mask = data
- if data.dim() == 1:
- for _ in range(data1.dim() - 1):
- data = data.unsqueeze(data.dim())
- cond_data = data.expand_as(data1)
- cond_mask = data.expand_as(mask1)
- res_data = torch.where(cond_data, data1, data2)
- res_mask = torch.where(cond_mask, mask1, mask2)
- res_dims = dims1.__or__(dims2)
- return res_data, res_mask, res_dims
-
-
-@torch.jit.script
-def batch_where_scalar(cond, data1, mask1, dims1, data2, mask2, dims2):
- cond = torch.zeros([1], dtype=torch.uint8)
- res_data = torch.where(cond, data1, data2)
- res_mask = torch.where(cond, mask1, mask2)
- res_dims = torch.where(cond, dims1, dims2)
- return res_data, res_mask, res_dims
-
-
-@torch.jit.script
-def batch_update(batch_data, batch_mask, batch_dims, new_data, new_mask, new_dims):
- data = torch.where(new_mask, new_data, batch_data)
- return data, new_mask, new_dims # TODO: consider whether return new_mask and new_dims
-
-
-@torch.jit.script
-def batch_any(data, mask, dims):
- return torch.gt(torch.sum(data * mask), 0)
-
-
-@torch.jit.script
-def batch_type_as(data, mask, dims, data1, mask1, dims1):
- return data.type_as(data1), mask, dims
-
-
-@torch.jit.script
-def batch_gt(data, mask, dims, data1, mask1, dims1):
- return torch.gt(data, data1), mask * mask1, dims.__or__(dims1)
-
-
-@torch.jit.script
-def batch_gt_scalar(data1, data2):
- return torch.gt(data1, data2)
-
-
-@torch.jit.script
-def batch_gt_one_scalar(data, mask, dims, other_):
- other = float(other_)
- return torch.gt(data, other), mask, dims
-
-
-@torch.jit.script
-def batch_lt(data, mask, dims, data1, mask1, dims1):
- return torch.lt(data, data1), mask * mask1, dims.__or__(dims1)
-
-
-@torch.jit.script
-def batch_eq(data, mask, dims, data1, mask1, dims1):
- return torch.eq(data, data1), mask * mask1, dims.__or__(dims1)
-
-
-@torch.jit.script
-def batch_size(data, mask, dims, dim_):
- dim = int(dim_)
- return data.size(dim)
-
-
-@torch.jit.script
-def batch_dim(data, mask, dims):
- return data.dim()
-
-
-@torch.jit.script
-def batch_squeeze(data, mask, dims, dim_):
- if int(dim_) < 0:
- dim_ = dim_ + data.dim()
- dim = int(dim_)
- # if dim == 0:
- # raise ValueError("cannot do squeeze along batch_dim")
- data = data.squeeze(dim)
- mask = mask.squeeze(dim)
- dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_unsqueeze(data, mask, dims, dim_):
- if int(dim_) < 0:
- dim_ = dim_ + data.dim() + 1
- dim = int(dim_)
- # if dim == 0:
- # raise ValueError("cannot do unsqueeze along batch_dim")
- data = data.unsqueeze(dim)
- mask = mask.unsqueeze(dim)
- dims = torch.cat((dims[:dim], torch.zeros([1], dtype=torch.uint8), dims[dim:dims.size(0)]))
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_argmax(data, mask, dims, dim_, keepdim_):
- dim = int(dim_)
- keepdim = bool(keepdim_)
- # if dim == 0:
- # raise ValueError("cannot do argmax along batch_dim")
- batch_size = data.size(0)
- res_data = torch.zeros([0])
- for i in range(batch_size):
- if bool(dims[dim - 1]):
- if dim - 1 != 0:
- m = mask[i].transpose(0, dim - 1)
- else:
- m = mask[i]
- valid_num = m.sum(0, keepdim=True)
- while(valid_num.dim() >= 1):
- valid_num = valid_num[0]
- d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
- else:
- d = data[i].unsqueeze(0)
- d = d.argmax(dim, keepdim)
- if i == 0:
- res_data = d
- else:
- res_data = torch.cat([res_data, d], 0)
- if keepdim:
- mask = mask
- else:
- mask = mask.select(dim, 0)
- dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
- return res_data, mask, dims
-
-
-@torch.jit.script
-def batch_topk(data, mask, dims, k_, dim_, largest_, sorted_):
- k = int(k_)
- dim = int(dim_)
- largest = bool(largest_)
- sorted = bool(sorted_)
- # if dim == 0:
- # raise ValueError("cannot do topk along batch_dim")
- batch_size = data.size(0)
- res_data = torch.zeros([0])
- res_index = torch.zeros([0])
- for i in range(batch_size):
- if bool(dims[dim - 1]):
- if dim - 1 != 0:
- m = mask[i].transpose(0, dim - 1)
- else:
- m = mask[i]
- valid_num = m.sum(0, keepdim=True)
- while(valid_num.dim() >= 1):
- valid_num = valid_num[0]
- d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
- else:
- d = data[i].unsqueeze(0)
- d, idx = d.topk(k, dim, largest, sorted)
- if i == 0:
- res_data = d
- res_index = idx
- else:
- res_data = torch.cat([res_data, d], 0)
- res_index = torch.cat([res_index, idx], 0)
- if bool(dims[dim - 1]):
- mask = mask.narrow(dim, 0, k)
- return res_data, mask, dims, res_index, mask, dims
-
-
-@torch.jit.script
-def batch_softmax(data, mask, dims, dim_):
- dim = int(dim_)
- # if dim == 0:
- # raise ValueError("cannot do softmax along batch_dim")
- batch_size = data.size(0)
- max_len = data.size(dim)
- res_data = torch.zeros([0])
- for i in range(batch_size):
- if bool(dims[dim - 1]):
- if dim - 1 != 0:
- m = mask[i].transpose(0, dim - 1)
- else:
- m = mask[i]
- valid_num = m.sum(0, keepdim=True)
- while(valid_num.dim() >= 1):
- valid_num = valid_num[0]
- valid_num = int(valid_num)
- d = data[i].unsqueeze(0).narrow(dim, 0, valid_num).softmax(dim)
- if valid_num < max_len:
- d = torch.cat([d, data[i].unsqueeze(0).narrow(dim, valid_num, max_len - valid_num)], dim)
- else:
- d = data[i].unsqueeze(0).softmax(dim)
- if i == 0:
- res_data = d
- else:
- res_data = torch.cat([res_data, d], 0)
- return res_data, mask, dims
-
-
-# size argument in dynamic dimension has to be -1
-# in static dimension, size has to be specified, -1 is not supported
-@torch.jit.script
-def batch_view(data, mask, dims, sizes):
- batch_size = data.size(0)
- # if(sizes[0] != batch_size and sizes[0] != -1 and sizes[0] != 1):
- # raise "first dim in view must be 1, -1, or batch size"
- # for i in range(dims.size(0)):
- # if dims[0] == 1 and sizes[i + 1] != -1:
- # raise "size argument in dynamic dimension has to be -1"
- sizes = sizes.type_as(torch.ones([1], dtype=torch.int))
- data_sizes_ = torch.cat([torch.ones([1], dtype=torch.int) * batch_size, sizes.narrow(0, 1, sizes.size(0) - 1)], 0)
- data_sizes = data_sizes_._tensor_to_list()
- res_data = data.view(data_sizes)
- mask_sizes_ = data_sizes_.narrow(0, 0, 1)
- res_dims = data_sizes_.narrow(0, 0, 1)
- for i_ in range(sizes.size(0) - 1):
- i = i_ + 1
- if bool(sizes[i] == -1):
- cur_size_ = mask.size(i)
- cur_dim = 1
- else:
- cur_size_ = 1
- cur_dim = 0
- mask_sizes_ = torch.cat([mask_sizes_, torch.ones([1], dtype=torch.int) * cur_size_])
- res_dims = torch.cat([res_dims, torch.ones([1], dtype=torch.int) * cur_dim])
- mask_sizes = mask_sizes_._tensor_to_list()
- res_mask = mask.view(mask_sizes)
- return res_data, res_mask, res_dims.narrow(0, 1, res_dims.size(0) - 1).type_as(dims)
-
-
-@torch.jit.script
-def batch_cat2(data1, mask1, dims1, data2, mask2, dims2, dim_):
- dim = int(dim_)
- data = torch.cat([data1, data2], dim)
- if bool(dims1[dim - 1]):
- mask = torch.cat([mask1, mask2], dim)
- else:
- mask = mask1
- return data, mask, dims1
-
-
-@torch.jit.script
-def batch_cat3(data1, mask1, dims1, data2, mask2, dims2, data3, mask3, dims3, dim_):
- dim = int(dim_)
- data = torch.cat([data1, data2, data3], dim)
- if bool(dims1[dim - 1]):
- mask = torch.cat([mask1, mask2, mask3], dim)
- else:
- mask = mask1
- return data, mask, dims1
-
-
-@torch.jit.script
-def batch_narrow(data, mask, dims, dimension_, start_, length_):
- dimension = int(dimension_)
- start = int(start_)
- length = int(length_)
- # if dimension == 0:
- # raise ValueError("cannot do narrow along batch_dim")
- data = data.narrow(dimension, start, length)
- if bool(dims[dimension - 1]):
- mask = mask.narrow(dimension, start, length)
- else:
- mask = mask.narrow(dimension, 0, 1)
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_sum(data, mask, dims):
- data = data * mask.type_as(data)
- for _ in range(dims.size(0)):
- data = data.sum(1)
- mask = torch.ones([data.size(0)], dtype=torch.uint8)
- dims = dims[:0] # empty tensor
- return data, mask, dims
-
-
-@torch.jit.script
-def batch_from_scalar_tensor(data):
- data = data.unsqueeze(0)
- mask = torch.ones([1], dtype=torch.uint8)
- dims = torch.zeros([0], dtype=torch.uint8)
- return data, mask, dims
-
-torch.register_batch_operator("tanh", batch_tanh.graph)
-torch.register_batch_operator("sigmoid", batch_sigmoid.graph)
-torch.register_batch_operator("relu", batch_relu.graph)
-torch.register_batch_operator("neg", batch_neg.graph)
-torch.register_batch_operator("neg", batch_neg_scalar.graph)
-torch.register_batch_operator("add", batch_add.graph)
-torch.register_batch_operator("add", batch_add_scalar.graph)
-torch.register_batch_operator("sub", batch_sub.graph)
-torch.register_batch_operator("sub", batch_sub_scalar.graph)
-torch.register_batch_operator("mul", batch_mul.graph)
-torch.register_batch_operator("mul", batch_mul_scalar.graph)
-torch.register_batch_operator("div", batch_div.graph)
-torch.register_batch_operator("matmul", batch_matmul.graph)
-torch.register_batch_operator("mm", batch_mm.graph)
-torch.register_batch_operator("fmod", batch_fmod.graph)
-torch.register_batch_operator("zeros_like", batch_zeros_like.graph)
-torch.register_batch_operator("select", batch_select.graph)
-torch.register_batch_operator("index_select", batch_index_select.graph)
-torch.register_batch_operator("view_as", batch_view_as.graph)
-torch.register_batch_operator("where", batch_where.graph)
-torch.register_batch_operator("where", batch_where_scalar.graph)
-torch.register_batch_operator("update", batch_update.graph)
-torch.register_batch_operator("any", batch_any.graph)
-torch.register_batch_operator("type_as", batch_type_as.graph)
-torch.register_batch_operator("gt", batch_gt.graph)
-torch.register_batch_operator("gt", batch_gt_scalar.graph)
-torch.register_batch_operator("gt", batch_gt_one_scalar.graph)
-torch.register_batch_operator("lt", batch_lt.graph)
-torch.register_batch_operator("eq", batch_eq.graph)
-torch.register_batch_operator("size", batch_size.graph)
-torch.register_batch_operator("dim", batch_dim.graph)
-torch.register_batch_operator("squeeze", batch_squeeze.graph)
-torch.register_batch_operator("unsqueeze", batch_unsqueeze.graph)
-torch.register_batch_operator("argmax", batch_argmax.graph)
-torch.register_batch_operator("topk", batch_topk.graph)
-torch.register_batch_operator("softmax", batch_softmax.graph)
-torch.register_batch_operator("view", batch_view.graph)
-torch.register_batch_operator("cat", batch_cat2.graph)
-torch.register_batch_operator("cat", batch_cat3.graph)
-torch.register_batch_operator("narrow", batch_narrow.graph)
-torch.register_batch_operator("sum", batch_sum.graph)
-torch.register_batch_operator("batch_from_scalar_tensor", batch_from_scalar_tensor.graph)