Delete batch tensor (#18575)
authorElias Ellison <eellison@fb.com>
Fri, 29 Mar 2019 06:07:45 +0000 (23:07 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 29 Mar 2019 06:13:27 +0000 (23:13 -0700)
Summary:
Deleting batch tensor since we are no longer maintaining the project and keeping it functional is blocking other improvements.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18575

Differential Revision: D14671126

Pulled By: eellison

fbshipit-source-id: b42d5b699c4d12171ed95e6d3a977532167f0d2c

test/test_docs_coverage.py
test/test_jit.py
torch/CMakeLists.txt
torch/csrc/jit/batched/BatchTensor.cpp [deleted file]
torch/csrc/jit/batched/BatchTensor.h [deleted file]
torch/csrc/jit/init.cpp
torch/csrc/jit/passes/to_batch.cpp [deleted file]
torch/csrc/jit/passes/to_batch.h [deleted file]
torch/csrc/jit/script/init.cpp
torch/jit/__init__.py
torch/jit/batchop.py [deleted file]

index 811221a..3b565c3 100644 (file)
@@ -36,8 +36,7 @@ class TestDocCoverage(unittest.TestCase):
         whitelist = {
             # below are some jit functions
             'wait', 'fork', 'parse_type_comment', 'import_ir_module',
-            'to_batch_graph', 'import_ir_module_from_buffer',
-            'register_batch_operator', 'merge_type_from_type_comment',
+            'import_ir_module_from_buffer', 'merge_type_from_type_comment',
 
             # below are symbols mistakely binded to torch.*, but should
             # go to torch.nn.functional.* instead
index b1b5c98..02fe570 100644 (file)
@@ -48,7 +48,6 @@ from copy import deepcopy
 import random
 from typing import List, Dict, Optional, Tuple
 from torch.jit.frontend import NotSupportedError
-from torch.jit import BatchTensor
 from torch import Tensor
 from torch.jit.annotations import BroadcastingList2, BroadcastingList3
 
@@ -2465,548 +2464,6 @@ class TestJit(JitTestCase):
         self.assertTrue(tested_blocks)
 
 
-class TestBatched(TestCase):
-    # generate random examples and create an batchtensor with them
-    def rand_batch(self, *dims):
-        dims = [dim for dim in dims if dim != ()]
-        xs = [torch.rand(1, *(random.randint(1, size) if b else size for b, size in dims[1:]),
-                         requires_grad=True) for i in range(dims[0])]
-        xb = BatchTensor(xs, torch.tensor([b for b, d in dims[1:]]).byte())
-        return xs, xb
-
-    def test_create_batchtensor(self):
-        # create from tensorlist
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2), (True, 5))
-        self.assertEqual(xs, batch.examples())
-        # create from data, mask, dims
-        batch2 = BatchTensor(batch.get_data(), batch.get_mask(), batch.get_dims())
-        self.assertEqual(xs, batch2.examples())
-        # expand a tensor to a batchtensor given batch_size
-        xs = torch.rand(3, 4, 5)
-        batch3 = BatchTensor(xs, 2)
-        xs = xs.unsqueeze(0)
-        self.assertEqual([xs, xs], batch3.examples())
-
-    def test_batch_elementwise_unary(self):
-        @torch.jit.batch(batch_size=4)
-        def tanh(a):
-            return torch.tanh(a)
-
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
-        res_batch = tanh(batch)
-        res = [torch.tanh(xs[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_elementwise_binary(self):
-        @torch.jit.batch(batch_size=4)
-        def add(a, b):
-            return a + b
-
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
-        xs2, batch2 = xs, batch
-        res_batch = add(batch, batch2)
-        res = [torch.add(xs[j], xs2[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        # test broadcast
-        xs, batch = self.rand_batch(4, (False, 3), (False, 2))
-        b = torch.rand(3, 2)
-        res_batch = add(batch, b)
-        res = [torch.add(xs[j], b) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_mm(self):
-        @torch.jit.batch(batch_size=4)
-        def mm(a, b):
-            return torch.mm(a, b)
-
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
-        res_batch = mm(batch, batch2)
-        res = [torch.mm(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        # test broadcast
-        b = torch.rand(2, 4)
-        res_batch = mm(batch, b)
-        res = [torch.mm(xs[j].squeeze(0), b).unsqueeze(0) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_matmul(self):
-        @torch.jit.batch(batch_size=4)
-        def matmul(a, b):
-            return torch.matmul(a, b)
-
-        def matmul_test(xs, batch, xs2, batch2):
-            ys = [torch.matmul(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
-            ybs = matmul(batch, batch2)
-            self.assertEqual(ys, ybs.examples())
-
-        # 1 dimension * 1 dimension
-        xs, batch = self.rand_batch(4, (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 2))
-        matmul_test(xs, batch, xs2, batch2)
-        # 1 dimension * 2 dimension
-        xs, batch = self.rand_batch(4, (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
-        matmul_test(xs, batch, xs2, batch2)
-        # 2 dimension * 1 dimensions
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 2))
-        matmul_test(xs, batch, xs2, batch2)
-        # 2 dimension * 2 dimension
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
-        matmul_test(xs, batch, xs2, batch2)
-
-    def test_batch_select(self):
-        @torch.jit.batch(batch_size=4)
-        def select(x):
-            return torch.select(x, 1, 0)
-
-        xs, batch = self.rand_batch(4, (True, 3), (True, 2))
-        res_batch = select(batch)
-        res = [torch.select(xs[j], 1, 0) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        xs, batch = self.rand_batch(4, (False, 3), (True, 2))
-        res_batch = select(batch)
-        res = [torch.select(xs[j], 1, 0) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_index_select(self):
-        @torch.jit.batch(batch_size=4)
-        def index_select(x, ind):
-            return x.index_select(1, ind)
-
-        xs, batch = self.rand_batch(4, (False, 5), (True, 2))
-        ind = [torch.randint(0, 4, (1,), dtype=torch.long) for i in range(4)]
-        ind_batch = BatchTensor(ind, torch.tensor([]).byte())
-        res_batch = index_select(batch, ind_batch)
-        res = [torch.index_select(xs[j], 1, ind[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_where(self):
-        @torch.jit.batch(batch_size=4)
-        def where(c, a, b):
-            return torch.where(c, a, b)
-
-        xs, batch = self.rand_batch(4, (False, 3), (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 3), (False, 2))
-
-        dims = [4, (False, 3), (False, 2)]
-        xs_cond = [torch.rand(1, 3, 2).byte() for i in range(dims[0])]
-        batch_cond = BatchTensor(xs_cond, torch.tensor([b for b, d in dims[1:]]))
-
-        res_batch = where(batch_cond, batch, batch2)
-        res = [torch.where(xs_cond[j], xs[j], xs2[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_argmax(self):
-        @torch.jit.batch(batch_size=4)
-        def argmax(a):
-            return torch.argmax(a, 1)
-
-        xs, batch = self.rand_batch(4, (True, 5), (True, 6))
-        res_batch = argmax(batch)
-        res = [torch.argmax(xs[j], 1) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        @torch.jit.batch(batch_size=4)
-        def argmax(a):
-            return torch.argmax(a, 1, False)
-
-        res_batch = argmax(batch)
-        res = [torch.argmax(xs[j], 1, False) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_topk(self):
-        @torch.jit.batch(batch_size=4)
-        def topk(a):
-            return torch.topk(a, 3, 1)
-
-        xs, batch = self.rand_batch(4, (False, 5), (True, 6))
-
-        # along static dim
-        res_batch = topk(batch)
-        res = [torch.topk(xs[j], 3, 1)[0] for j in range(4)]
-        res_idx = [torch.topk(xs[j], 3, 1)[1] for j in range(4)]
-        self.assertEqual(res, res_batch[0].examples())
-        self.assertEqual(res_idx, res_batch[1].examples())
-
-        @torch.jit.batch(batch_size=4)
-        def topk(a):
-            return torch.topk(a, 1, 2)
-
-        # along dynamic dim
-        res_batch = topk(batch)
-        res = [torch.topk(xs[j], 1, 2)[0] for j in range(4)]
-        res_idx = [torch.topk(xs[j], 1, 2)[1] for j in range(4)]
-        self.assertEqual(res, res_batch[0].examples())
-        self.assertEqual(res_idx, res_batch[1].examples())
-
-    def test_batch_softmax(self):
-        @torch.jit.batch(batch_size=4)
-        def softmax(a):
-            return torch.softmax(a, 1)
-
-        xs, batch = self.rand_batch(4, (False, 5), (True, 6))
-
-        # along static dim
-        res_batch = softmax(batch)
-        res = [torch.softmax(xs[j], 1) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        @torch.jit.batch(batch_size=4)
-        def softmax(a):
-            return torch.softmax(a, 2)
-
-        # along dynamic dim
-        res_batch = softmax(batch)
-        res = [torch.softmax(xs[j], 2) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_view(self):
-        @torch.jit.batch(batch_size=4)
-        def view(a):
-            return a.view([4, -1, 3])
-
-        xs, batch = self.rand_batch(4, (True, 5), (False, 3))
-        res_batch = view(batch)
-        res = [xs[j].view([1, -1, 3]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_cat(self):
-        @torch.jit.batch(batch_size=4)
-        def cat2(a, b):
-            return torch.cat([a, b], 2)
-
-        xs, batch = self.rand_batch(4, (True, 5), (False, 3))
-        xs2, batch2 = xs, batch
-        res_batch = cat2(batch, batch2)
-        res = [torch.cat([xs[j], xs2[j]], 2) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_sum(self):
-        @torch.jit.batch(batch_size=4)
-        def batch_sum(a):
-            return a.sum()
-
-        xs, batch = self.rand_batch(4, (True, 5), (False, 3))
-        res_batch = batch_sum(batch)
-        res = [xs[j].sum().unsqueeze(0) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_if_else(self):
-        def single_if(a, b):
-            if bool(a > b):
-                a = a + b
-            else:
-                a = a - b
-            return a
-
-        batch_if = torch.jit.batch(batch_size=4)(single_if)
-
-        a, batch_a = self.rand_batch(4, ())
-        b, batch_b = self.rand_batch(4, ())
-        res_batch = batch_if(batch_a, batch_b)
-        res = [single_if(a[j], b[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        script_if = torch.jit.script(single_if)
-        torch.to_batch_graph(script_if.graph)
-
-    def test_if_else_with_scalar(self):
-        def single_if(a, b):
-            if bool(a > 0.1):
-                a = a + b
-            else:
-                a = a - b
-            return a
-
-        batch_if = torch.jit.batch(batch_size=4)(single_if)
-
-        a, batch_a = self.rand_batch(4, ())
-        b, batch_b = self.rand_batch(4, ())
-        res_batch = batch_if(batch_a, batch_b)
-        res = [single_if(a[j], b[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        script_if = torch.jit.script(single_if)
-        torch.to_batch_graph(script_if.graph)
-
-    def test_if_noelse(self):
-        def single_if(a, b):
-            if bool(a > b):
-                a = a + b
-            return a
-
-        batch_if = torch.jit.batch(batch_size=4)(single_if)
-
-        a, batch_a = self.rand_batch(4, ())
-        b, batch_b = self.rand_batch(4, ())
-        res_batch = batch_if(batch_a, batch_b)
-        res = [single_if(a[j], b[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        script_if = torch.jit.script(single_if)
-        torch.to_batch_graph(script_if.graph)
-
-    def test_if_noelse_with_scalar(self):
-        def single_if(a, b):
-            if bool(a > 0.1):
-                a = a + b
-            return a
-
-        batch_if = torch.jit.batch(batch_size=4)(single_if)
-
-        a, batch_a = self.rand_batch(4, ())
-        b, batch_b = self.rand_batch(4, ())
-        res_batch = batch_if(batch_a, batch_b)
-        res = [single_if(a[j], b[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        script_if = torch.jit.script(single_if)
-        torch.to_batch_graph(script_if.graph)
-
-    def test_while(self):
-        def single_while(a, b):
-            while bool(a > b):
-                a = a - b
-            return a
-
-        batch_while = torch.jit.batch(batch_size=4)(single_while)
-
-        a, batch_a = self.rand_batch(4, ())
-        b = [torch.abs(torch.rand(1)) for i in range(4)]
-        batch_b = BatchTensor(b, torch.tensor([]).byte())
-        res_batch = batch_while(batch_a, batch_b)
-        res = [single_while(a[j], b[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        script_while = torch.jit.script(single_while)
-        torch.to_batch_graph(script_while.graph)
-
-    def test_for(self):
-        def single_for(x, y):
-            for _ in range(10):
-                x = x + y
-            return x
-
-        batch_for = torch.jit.batch(batch_size=4)(single_for)
-
-        a, batch_a = self.rand_batch(4, ())
-        b, batch_b = self.rand_batch(4, ())
-        res_batch = batch_for(batch_a, batch_b)
-        res = [single_for(a[j], b[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        script_for = torch.jit.script(single_for)
-        torch.to_batch_graph(script_for.graph)
-
-    def test_lstm(self):
-        def LSTM(x_all, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c):
-            for i in range(x_all.size(1)):
-                x = x_all.select(1, i)
-                i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
-                f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
-                o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
-                # activations
-                i_t = torch.sigmoid(i_t)
-                f_t = torch.sigmoid(f_t)
-                o_t = torch.sigmoid(o_t)
-                # cell computations
-                c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
-                c_t = torch.tanh(c_t)
-                c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
-                h_t = torch.mul(o_t, torch.tanh(c_t))
-                h = h_t
-                c = c_t
-            return h
-
-        LSTM_batch = torch.jit.batch(batch_size=4)(LSTM)
-
-        batch_size, input_size, hidden_size = 4, 3, 2
-        xs, batch = self.rand_batch(batch_size, (True, 4), (False, input_size))
-        hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
-        cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))
-
-        # input to hidden weights
-        w_xi = torch.rand(input_size, hidden_size)
-        w_xf = torch.rand(input_size, hidden_size)
-        w_xo = torch.rand(input_size, hidden_size)
-        w_xc = torch.rand(input_size, hidden_size)
-        # hidden to hidden weights
-        w_hi = torch.rand(hidden_size, hidden_size)
-        w_hf = torch.rand(hidden_size, hidden_size)
-        w_ho = torch.rand(hidden_size, hidden_size)
-        w_hc = torch.rand(hidden_size, hidden_size)
-        # bias terms
-        b_i = torch.rand(hidden_size)
-        b_f = torch.rand(hidden_size)
-        b_o = torch.rand(hidden_size)
-        b_c = torch.rand(hidden_size)
-
-        ys = [LSTM(xs[j], hx[j], cx[j], w_xi, w_xf, w_xo, w_xc,
-                   w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c) for j in range(batch_size)]
-        ybs = LSTM_batch(batch, h_batch, c_batch, w_xi, w_xf, w_xo, w_xc,
-                         w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
-        self.assertEqual(ys, ybs.examples())
-
-    @slowTest
-    def test_greedy_search(self):
-        def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
-                   b_i, b_f, b_o, b_c, w_hs, b_s, iter_num):
-            iter_count = torch.zeros_like(iter_num)
-            while bool(iter_count < iter_num):
-                iter_count = iter_count + 1
-                # LSTM Cell
-                i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
-                f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
-                o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
-                # activations
-                i_t = torch.sigmoid(i_t)
-                f_t = torch.sigmoid(f_t)
-                o_t = torch.sigmoid(o_t)
-                # cell computations
-                c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
-                c_t = torch.tanh(c_t)
-                c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
-                h_t = torch.mul(o_t, torch.tanh(c_t))
-                h = h_t
-                c = c_t
-                # calculate feature with max probability
-                s_t = torch.matmul(h_t, w_hs) + b_s
-                p_t = torch.softmax(s_t, 1)
-                i_t = torch.argmax(p_t, 1)
-                x = embed.index_select(1, i_t).squeeze(1)
-            return h
-
-        greedy_batch = torch.jit.batch(batch_size=4)(greedy)
-
-        batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
-        xs, batch = self.rand_batch(batch_size, (False, input_size))
-        hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
-        cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))
-        embed, embed_batch = self.rand_batch(batch_size, (False, vocab_size), (False, input_size))
-        iter_num = [torch.randint(2, 5, (1,)) for i in range(batch_size)]
-        iter_num_batch = BatchTensor(iter_num, torch.tensor([]).byte())
-
-        # input to hidden weights
-        w_xi = torch.rand(input_size, hidden_size)
-        w_xf = torch.rand(input_size, hidden_size)
-        w_xo = torch.rand(input_size, hidden_size)
-        w_xc = torch.rand(input_size, hidden_size)
-        # hidden to hidden weights
-        w_hi = torch.rand(hidden_size, hidden_size)
-        w_hf = torch.rand(hidden_size, hidden_size)
-        w_ho = torch.rand(hidden_size, hidden_size)
-        w_hc = torch.rand(hidden_size, hidden_size)
-        # bias terms
-        b_i = torch.rand(hidden_size)
-        b_f = torch.rand(hidden_size)
-        b_o = torch.rand(hidden_size)
-        b_c = torch.rand(hidden_size)
-        # hidden to vocab weights, bias
-        w_hs = torch.rand(hidden_size, vocab_size)
-        b_s = torch.rand(vocab_size)
-
-        ys = [greedy(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc,
-                     w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j]) for j in range(batch_size)]
-        ybs = greedy_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc,
-                           w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch)
-        self.assertEqual(ys, ybs.examples())
-
-    @slowTest
-    def test_beam_search(self):
-        def beam(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
-                 b_i, b_f, b_o, b_c, w_hs, b_s, iter_num, idx):
-            k = 5
-            vocab_size = embed.size(1)
-            iter_count = torch.zeros_like(iter_num)
-            max_len = idx.size(2)
-            while bool(iter_count < iter_num):
-                iter_count = iter_count + 1
-                # LSTM Cell
-                i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
-                f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
-                o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
-                # activations
-                i_t = torch.sigmoid(i_t)
-                f_t = torch.sigmoid(f_t)
-                o_t = torch.sigmoid(o_t)
-                # cell computations
-                c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
-                c_t = torch.tanh(c_t)
-                c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
-                h_t = torch.mul(o_t, torch.tanh(c_t))
-                h = h_t
-                c = c_t
-                # calculate features with max probability
-                s_t = torch.matmul(h_t, w_hs) + b_s
-                s_t = s_t.view([1, s_t.size(1) * s_t.size(2)])
-                p_t = torch.softmax(s_t, 1)
-                prob_t, idx_t = torch.topk(p_t, k, 1)
-                if(int(idx_t.dim()) > 1):
-                    idx_t_tmp = idx_t.squeeze(0)
-                else:
-                    idx_t_tmp = idx_t
-                new_y = torch.fmod(idx_t_tmp, vocab_size)
-                pre_y = idx_t_tmp / vocab_size
-                x = embed.index_select(1, new_y)
-                h = h_t.index_select(1, pre_y)
-                c = c_t.index_select(1, pre_y)
-                iter = int(iter_count[0])
-                idx = torch.cat([idx.narrow(2, 0, iter).index_select(1, pre_y),
-                                 torch.fmod(idx_t, vocab_size).unsqueeze(-1),
-                                 idx.narrow(2, iter, max_len - iter)], 2)
-                idx = idx.narrow(2, 0, max_len)
-            return idx
-
-        beam_batch = torch.jit.batch(batch_size=4)(beam)
-
-        k = 5
-        batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
-        max_len = 5
-        xs, batch = self.rand_batch(batch_size, (False, 1), (False, input_size))
-        hx, h_batch = self.rand_batch(batch_size, (False, 1), (False, hidden_size))
-        cx, c_batch = self.rand_batch(batch_size, (False, 1), (False, hidden_size))
-        embed, embed_batch = self.rand_batch(batch_size, (False, vocab_size), (False, input_size))
-        iter_num = [torch.randint(2, max_len + 1, (1,)) for i in range(batch_size)]
-        iter_num_batch = BatchTensor(iter_num, torch.tensor([]).byte())
-
-        # input to hidden weights
-        w_xi = torch.rand(input_size, hidden_size)
-        w_xf = torch.rand(input_size, hidden_size)
-        w_xo = torch.rand(input_size, hidden_size)
-        w_xc = torch.rand(input_size, hidden_size)
-        # hidden to hidden weights
-        w_hi = torch.rand(hidden_size, hidden_size)
-        w_hf = torch.rand(hidden_size, hidden_size)
-        w_ho = torch.rand(hidden_size, hidden_size)
-        w_hc = torch.rand(hidden_size, hidden_size)
-        # bias terms
-        b_i = torch.rand(1, hidden_size)
-        b_f = torch.rand(1, hidden_size)
-        b_o = torch.rand(1, hidden_size)
-        b_c = torch.rand(1, hidden_size)
-        # hidden to vocab weights, bias
-        w_hs = torch.rand(hidden_size, vocab_size)
-        b_s = torch.rand(1, vocab_size)
-
-        idx_batch = torch.jit.BatchTensor(torch.zeros([batch_size, k, max_len], dtype=torch.long),
-                                          torch.zeros([batch_size, 1, max_len]).byte(),
-                                          torch.tensor([0, 1]).byte())
-        idx = [torch.zeros([1, k, max_len], dtype=torch.long) for _ in range(batch_size)]
-
-        ys = [beam(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
-                   b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j], idx[j]).narrow(2, 0, int(iter_num[j]))
-              for j in range(batch_size)]
-        ybs = beam_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc,
-                         w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch, idx_batch)
-        self.assertEqual(ys, ybs.examples())
-
-
 def execWrapper(code, glob, loc):
     if PY2:
         exec(code) in glob, loc
index 5db5260..4e6fe9a 100644 (file)
@@ -520,13 +520,11 @@ if (BUILD_PYTHON)
     ${TORCH_SRC_DIR}/csrc/autograd/python_variable.cpp
     ${TORCH_SRC_DIR}/csrc/autograd/python_variable_indexing.cpp
     ${TORCH_SRC_DIR}/csrc/byte_order.cpp
-    ${TORCH_SRC_DIR}/csrc/jit/batched/BatchTensor.cpp
     ${TORCH_SRC_DIR}/csrc/jit/init.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/onnx.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/fixup_onnx_loop.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/peephole.cpp
-    ${TORCH_SRC_DIR}/csrc/jit/passes/to_batch.cpp
     ${TORCH_SRC_DIR}/csrc/jit/python_arg_flatten.cpp
     ${TORCH_SRC_DIR}/csrc/jit/python_interpreter.cpp
     ${TORCH_SRC_DIR}/csrc/jit/python_ir.cpp
diff --git a/torch/csrc/jit/batched/BatchTensor.cpp b/torch/csrc/jit/batched/BatchTensor.cpp
deleted file mode 100644 (file)
index 7d709a6..0000000
+++ /dev/null
@@ -1,98 +0,0 @@
-#include <torch/csrc/jit/batched/BatchTensor.h>
-
-namespace torch {
-namespace jit {
-
-BatchTensor::BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims) {
-  if (data.dim() != mask.dim() || mask.dim() != dims.size(0) + 1) {
-    throw std::runtime_error(
-        "malformed MaskedBatch with data.dim(): " + std::to_string(data.dim()) +
-        ", mask.dim(): " + std::to_string(mask.dim()) +
-        ", dims.size(0): " + std::to_string(dims.size(0)));
-  }
-  this->data = std::move(data);
-  this->mask = std::move(mask);
-  this->dims = std::move(dims);
-}
-
-BatchTensor::BatchTensor(const at::Tensor& data, int64_t batch_size) {
-  dims = at::empty(data.dim(), data.options().dtype(at::kByte));
-  dims.fill_(0);
-  std::vector<int64_t> sizes(data.dim() + 1, -1);
-  sizes[0] = batch_size;
-  this->data = data.unsqueeze(0).expand(sizes);
-  std::vector<int64_t> mask_sizes(data.dim() + 1, 1);
-  mask_sizes[0] = batch_size;
-  mask = at::empty(mask_sizes, data.options().dtype(at::kByte));
-  mask.fill_(1);
-}
-
-BatchTensor::BatchTensor(
-    const std::vector<at::Tensor>& datalist,
-    at::Tensor dims) {
-  auto bs = datalist.size();
-  std::vector<int64_t> sizes(dims.size(0) + 1, 0),
-      mask_sizes(dims.size(0) + 1, 0);
-  sizes[0] = bs;
-  mask_sizes[0] = bs;
-  for (int64_t i = 1; i < dims.size(0) + 1; i++) {
-    for (const auto& x : datalist) {
-      sizes[i] = std::max(sizes[i], x.size(i));
-    }
-    mask_sizes[i] = *dims[i - 1].data<uint8_t>() ? sizes[i] : 1;
-  }
-  data = at::empty(sizes, datalist[0].options());
-  data.fill_(0);
-  mask = at::empty(mask_sizes, datalist[0].options().dtype(at::kByte));
-  mask.fill_(0);
-  for (std::size_t i = 0; i < datalist.size(); i++) {
-    auto data_item = data.narrow(0, i, 1);
-    auto mask_item = mask.narrow(0, i, 1);
-    for (int64_t j = 0; j < dims.size(0); j++) {
-      if (*dims[j].data<uint8_t>()) {
-        data_item = data_item.narrow(j + 1, 0, datalist[i].size(j + 1));
-        mask_item = mask_item.narrow(j + 1, 0, datalist[i].size(j + 1));
-      }
-    }
-    data_item += datalist[i];
-    mask_item.fill_(1);
-  }
-  this->dims = std::move(dims);
-}
-
-std::vector<at::Tensor> BatchTensor::examples() {
-  std::vector<at::Tensor> result;
-  // calculate number of valid entries in dth dimension of data
-  auto mask_sum = [](at::Tensor data, int d) -> int64_t {
-    data = data.sum(d, /*keepdim=*/true);
-    while (data.dim() >= 1)
-      data = data[0];
-    return *data.data<int64_t>();
-  };
-  for (int64_t i = 0; i < data.size(0); i++) {
-    auto data_tmp = data.narrow(0, i, 1);
-    for (int64_t d = 0; d < dims.size(0); d++) {
-      if (*dims[d].data<uint8_t>()) {
-        data_tmp = data_tmp.narrow(d + 1, 0, mask_sum(mask[i], d));
-      }
-    }
-    result.push_back(data_tmp);
-  }
-  return result;
-}
-
-void initBatchTensorBindings(PyObject* module) {
-  auto m = py::handle(module).cast<py::module>();
-  auto jit = m.def_submodule("_jit");
-  py::class_<BatchTensor>(jit, "BatchTensor")
-      .def(py::init<at::Tensor, at::Tensor, at::Tensor>())
-      .def(py::init<at::Tensor, int64_t>())
-      .def(py::init<std::vector<at::Tensor>, at::Tensor>())
-      .def("examples", &BatchTensor::examples)
-      .def("get_data", &BatchTensor::get_data)
-      .def("get_mask", &BatchTensor::get_mask)
-      .def("get_dims", &BatchTensor::get_dims);
-}
-
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/batched/BatchTensor.h b/torch/csrc/jit/batched/BatchTensor.h
deleted file mode 100644 (file)
index b1b49e8..0000000
+++ /dev/null
@@ -1,57 +0,0 @@
-#pragma once
-#include <ATen/ATen.h>
-#include <ATen/Tensor.h>
-#include <torch/csrc/jit/pybind.h>
-#include <iostream>
-#include <vector>
-
-namespace torch {
-namespace jit {
-struct BatchTensor {
- public:
-  BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims);
-  // expand a tensor to a batchtensor given batch_size
-  BatchTensor(const at::Tensor& data, int64_t batch_size);
-  BatchTensor(const std::vector<at::Tensor>& datalist, at::Tensor dims);
-  const char* toString() const {
-    return "BatchTensor";
-  }
-  at::IntArrayRef sizes() const {
-    return data.sizes();
-  }
-  int64_t dim() const {
-    return data.dim();
-  }
-  std::vector<at::Tensor> examples();
-  at::Tensor get_data() {
-    return data;
-  }
-  at::Tensor get_mask() {
-    return mask;
-  }
-  at::Tensor get_dims() {
-    return dims;
-  }
-
- public:
-  // data is a Tensor whose size is the batch size in the batch dimension,
-  // the size of all examples in static dimensions,
-  // and at least as large as the largest example in the batch in dynamic
-  // dimensions.
-  at::Tensor data;
-  // mask is a Tensor whose size is the batch size in the batch dimension,
-  // one in static dimensions,
-  // and at least as large as the largest example in the batch in dynamic
-  // dimensions. Each entry in the mask corresponds to one or more entries in
-  // the data array (singleton, i.e., static, dimensions are broadcasted), with
-  // a one in the mask denoting that the corresponding data entries represent
-  // valid, meaningful data and a zero denoting that they do not.
-  at::Tensor mask;
-  // dims is a 1-dimensional tensor with a bool for each non-batch dimension,
-  // representing whether that dimension is static (False) or dynamic (True).
-  at::Tensor dims;
-};
-
-void initBatchTensorBindings(PyObject* module);
-} // namespace jit
-} // namespace torch
index f46dd9f..bd66058 100644 (file)
@@ -2,7 +2,6 @@
 #include <torch/csrc/utils/pybind.h>
 
 #include <torch/csrc/jit/argument_spec.h>
-#include <torch/csrc/jit/batched/BatchTensor.h>
 #include <torch/csrc/jit/export.h>
 #include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/fuser/kernel_cache.h>
@@ -31,7 +30,6 @@
 #include <torch/csrc/jit/passes/remove_inplace_ops.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
 #include <torch/csrc/jit/passes/specialize_autogradzero.h>
-#include <torch/csrc/jit/passes/to_batch.h>
 #include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
 #include <torch/csrc/jit/pybind_utils.h>
 #include <torch/csrc/jit/python_arg_flatten.h>
@@ -483,8 +481,6 @@ void initJITBindings(PyObject* module) {
   tracer::initPythonTracerBindings(module);
   script::initTreeViewBindings(module);
   script::initJitScriptBindings(module);
-  initBatchTensorBindings(module);
-  initRegisterBatchOpsBindings(module);
 }
 
 } // namespace jit
diff --git a/torch/csrc/jit/passes/to_batch.cpp b/torch/csrc/jit/passes/to_batch.cpp
deleted file mode 100644 (file)
index 445e02b..0000000
+++ /dev/null
@@ -1,610 +0,0 @@
-#include <torch/csrc/jit/passes/to_batch.h>
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
-#include <torch/csrc/jit/script/compiler.h>
-
-namespace torch {
-namespace jit {
-
-std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>>
-    ToBatch::batch_operator_table;
-
-std::shared_ptr<Graph> ToBatch::getBatchOperator(
-    const std::string& name,
-    int64_t num_inputs) {
-  if (batch_operator_table.find(name) == batch_operator_table.end()) {
-    throw std::runtime_error(
-        "function " + name + " is not supported in batched tensor yet");
-  }
-  auto ops = batch_operator_table.at(name);
-  if (num_inputs == -1) // default function
-    return ops[0];
-  for (auto op : ops) {
-    if (size_t(num_inputs) == op->inputs().size())
-      return op;
-  }
-  throw std::runtime_error(
-      "function " + name + " with " + std::to_string(num_inputs) +
-      " inputs is not supported in batched tensor yet");
-}
-
-std::vector<Value*> inlineUnpackedCallTo(
-    Graph& g,
-    Graph& callee,
-    ArrayRef<Value*> inputs) {
-  return inlineCallTo(g, callee, inputs, /*unpack_outputs=*/true);
-}
-
-// replace aten operator node with BatchTensor operator graph
-void ToBatch::visitAten(Node* n, Block* block, Block* res_block) {
-  auto res_graph = res_block->owningGraph();
-  auto func_name = std::string(n->kind().toUnqualString());
-  std::vector<Value*> new_inputs;
-  for (Value* input : n->inputs()) {
-    if (rn_env.find(input) == rn_env.end()) { // non-tensor input
-      auto new_input = batch_map.at(input);
-      new_inputs.insert(new_inputs.end(), new_input.begin(), new_input.end());
-    } else { // batched tensor input
-      new_inputs.push_back(rn_env.at(input));
-    }
-  }
-
-  // transform scalar to tensor before pass to batch operator script
-  for (auto& input : new_inputs) {
-    if (input->type() == IntType::get() || input->type() == FloatType::get() ||
-        input->type() == BoolType::get()) {
-      auto to_tensor_node = res_graph->createNumToTensor(input);
-      res_graph->insertNode(to_tensor_node);
-      input = to_tensor_node->output();
-    }
-  }
-
-  auto batch_graph = getBatchOperator(func_name, new_inputs.size());
-  auto outputs =
-      inlineUnpackedCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);
-
-  // Assume all outputs from inlined operator implementation are in the triple
-  // form batched tensor or just a single non-tensor.
-  if (outputs.size() == 1) {
-    // if previous output is scalar, transform new output back to scalar from
-    // dynamic
-    TypePtr orig_type = n->outputs()[0]->type();
-    if (!orig_type->isSubtypeOf(outputs[0]->type())) {
-      Symbol op;
-      if (orig_type == IntType::get()) {
-        op = prim::Int;
-      } else if (orig_type == FloatType::get()) {
-        op = prim::Float;
-      } else if (orig_type == BoolType::get()) {
-        op = prim::Bool;
-      } else {
-        throw std::runtime_error(
-            "NYI: scalar types other than int, float, and bool are not supported yet");
-      }
-      rn_env[n->outputs()[0]] = res_graph->insert(op, {outputs[0]});
-    } else {
-      rn_env[n->outputs()[0]] = outputs[0];
-    }
-  } else {
-    for (size_t i = 0; i < n->outputs().size(); i++) {
-      auto output = n->outputs()[i];
-      batch_map[output] = std::vector<Value*>(
-          outputs.begin() + i * EXP_BTENSOR_SIZE,
-          outputs.begin() + i * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE);
-    }
-  }
-}
-
-// clone prim::Constant to new graph
-// batching transformation is applied to the output of prim::NumToTensor.
-// If there is a prim::NumToTensor following prim::Constant, it will be finally
-// transformed to BatchTensor.
-void ToBatch::visitConstant(Node* n, Block* block, Block* res_block) {
-  auto res_graph = res_block->owningGraph();
-  auto* r_node = res_graph->createClone(n, rn_fn);
-  res_block->appendNode(r_node);
-  rn_env[n->output()] = r_node->output();
-}
-
-// change return tensor to expanded batched tensor, eg: {data, mask, dims}
-void ToBatch::visitNumToTensor(Node* n, Block* block, Block* res_block) {
-  auto res_graph = res_block->owningGraph();
-  auto* r_node = res_graph->createClone(n, rn_fn);
-  res_block->appendNode(r_node);
-  auto outputs = inlineUnpackedCallTo(
-      *res_block->owningGraph(),
-      *getBatchOperator("batch_from_scalar_tensor"),
-      r_node->outputs());
-  batch_map[n->output()] = outputs;
-}
-
-// clone prim::TensorToNum to new graph
-void ToBatch::visitTensorToNum(Node* n, Block* block, Block* res_block) {
-  auto res_graph = res_block->owningGraph();
-  if (rn_env.find(n->input()) == rn_env.end()) {
-    rn_env[n->input()] = batch_map.at(n->input())[0];
-  }
-  auto* r_node = res_graph->createClone(n, rn_fn);
-  res_block->appendNode(r_node);
-  rn_env[n->output()] = r_node->output();
-  batch_map[n->output()] = batch_map.at(n->input());
-}
-
-// clone prim::ListConstruct to new graph
-void ToBatch::visitListConstruct(Node* n, Block* block, Block* res_block) {
-  auto res_graph = res_block->owningGraph();
-  if (n->inputs()[0]->type() ==
-      TensorType::get()) { // TensorList: expand directly
-    std::vector<Value*> inputs;
-    for (Value* input : n->inputs()) {
-      auto res = batch_map.at(input);
-      inputs.insert(inputs.end(), res.begin(), res.end());
-    }
-    batch_map[n->output()] = inputs;
-  } else { // ScalarList: transform to tensor, then transform back
-    for (Value* input : n->inputs()) {
-      if (rn_env.find(input) == rn_env.end()) {
-        rn_env[input] = batch_map.at(input)[0];
-      }
-    }
-    auto* r_node = res_graph->createClone(n, rn_fn);
-    res_block->appendNode(r_node);
-    // transform int[] to tensor
-    auto to_tensor_node =
-        res_graph->create(Symbol::fromQualString("aten::_list_to_tensor"));
-    to_tensor_node->addInput(r_node->output());
-    res_block->appendNode(to_tensor_node);
-    rn_env[n->output()] = to_tensor_node->output();
-  }
-}
-
-// clang-format off
-// prim::If transformation:
-// elif is not supported
-//
-// transformation example:
-// @torch.jit.batch(batch_size=4)
-// def batch_if(a, b):
-//     if a > b:
-//         a += b
-//     else:
-//         a -= b
-//     return a
-//
-// original graph:
-// graph(%a.1 : Dynamic
-//       %b : Dynamic) {
-//   %2 : Dynamic = aten::gt(%a.1, %b)
-//   %a : Dynamic = prim::If(%2)
-//     block0() {
-//       %a.2 : Dynamic = aten::add[alpha={1}](%a.1, %b)
-//       -> (%a.2)
-//     }
-//     block1() {
-//       %a.3 : Dynamic = aten::sub[alpha={1}](%a.1, %b)
-//       -> (%a.3)
-//     }
-//   return (%a);
-// }
-//
-// transformed graph:
-// graph(%a.1_data : Dynamic
-//       %a.1_mask : Dynamic
-//       %a.1_dims : Dynamic
-//       %b_data : Dynamic
-//       %b_mask : Dynamic
-//       %b_dims : Dynamic) {
-//   %6 : Dynamic = aten::gt(%a.1_data, %b_data)  // calculate condition
-//   %7 : Dynamic = aten::mul(%a.1_mask, %b_mask)
-//   %8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
-//   %9 : int = prim::TensorToNum(%6)
-//   %10 : Long() = prim::Constant[value={1}]()  // if_block
-//   %alpha.1 : float = prim::TensorToNum(%10)
-//   %data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha.1)
-//   %mask.1 : Dynamic = aten::mul(%a.1_mask, %b_mask)
-//   %dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
-//   %15 : Long() = prim::Constant[value={1}]()  // else_block
-//   %alpha : float = prim::TensorToNum(%15)
-//   %data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha)
-//   %mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
-//   %dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
-//   %20 : Dynamic = aten::type_as(%7, %6)   // combine two outputs (batch_where)
-//   %cond_mask.1 : Dynamic = aten::mul(%6, %20)
-//   %22 : int = aten::dim(%cond_mask.1)
-//   %23 : int = prim::Constant[value=1]()
-//   %24 : int = aten::eq(%22, %23)
-//   %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%24)
-//     block0() {
-//       %28 : int = aten::dim(%data.1)
-//       %29 : int = prim::Constant[value=1]()
-//       %30 : int = aten::sub(%28, %29)
-//       %31 : int = prim::Constant[value=1]()
-//       %data.3 : Dynamic = prim::Loop(%30, %31, %cond_mask.1)
-//         block0(%_ : int, %34 : Dynamic) {
-//           %35 : int = prim::Constant[value=1]()
-//           %36 : int = aten::neg(%35)
-//           %data.2 : Dynamic = aten::unsqueeze(%34, %36)
-//           %38 : int = prim::Constant[value=1]()
-//           -> (%38, %data.2)
-//         }
-//       %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
-//       %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1)
-//       -> (%cond_data.1, %cond_mask.2, %data.3)
-//     }
-//     block1() {
-//       -> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
-//     }
-//   %res_data : Dynamic = aten::where(%cond_data, %data.1, %data.4)
-//   %res_mask : Dynamic = aten::where(%cond_mask, %mask.1, %mask)
-//   %res_dims : Dynamic = aten::__or__(%dims.1, %dims)
-//   return (%res_data, %res_mask, %res_dims);
-// }
-// clang-format on
-void ToBatch::visitIf(Node* n, Block* block, Block* res_block) {
-  toBatch(n->blocks()[0], res_block);
-  toBatch(n->blocks()[1], res_block);
-
-  // combine results from two if paths
-  for (size_t i = 0; i < n->outputs().size(); i++) {
-    std::vector<Value*> inputs;
-    if (batch_map.find(n->input()) == batch_map.end()) { // cond is scalar
-      inputs.push_back(rn_env.at(n->input()));
-    } else { // cond is tensor
-      auto cond = batch_map.at(n->input());
-      inputs.insert(inputs.end(), cond.begin(), cond.end());
-    }
-    auto if_output = batch_map.at(n->blocks()[0]->outputs()[i]);
-    inputs.insert(inputs.end(), if_output.begin(), if_output.end());
-    auto else_output = batch_map.at(n->blocks()[1]->outputs()[i]);
-    inputs.insert(inputs.end(), else_output.begin(), else_output.end());
-    auto outputs = inlineUnpackedCallTo(
-        *res_block->owningGraph(),
-        *getBatchOperator("where", inputs.size()),
-        inputs);
-    batch_map[n->outputs()[i]] = outputs;
-  }
-}
-
-// clang-format off
-// prim::Loop transformation:
-//
-// transformation example:
-// @torch.jit.batch(batch_size=4)
-// def batch_while(a, b):
-//     while a > b:
-//         a -= b
-//     return a
-//
-// original graph:
-// graph(%a.1 : Dynamic
-//       %b : Dynamic) {
-//   %2 : int = prim::Constant[value={2147483647}]()
-//   %3 : Dynamic = aten::gt(%a.1, %b)
-//   %a : Dynamic = prim::Loop(%2, %3, %a.1)
-//     block0(%4 : Dynamic, %5 : Dynamic) {
-//       %a.2 : Dynamic = aten::sub[alpha={1}](%5, %b)
-//       %9 : Dynamic = aten::gt(%a.2, %b)
-//       -> (%9, %a.2)
-//     }
-//   return (%a);
-// }
-//
-// transformed graph:
-// graph(%a.1_data : Dynamic
-//       %a.1_mask : Dynamic
-//       %a.1_dims : Dynamic
-//       %b_data : Dynamic
-//       %b_mask : Dynamic
-//       %b_dims : Dynamic) {
-//   %6 : int = prim::Constant[value=2147483647]()
-//   %7 : Dynamic = aten::gt(%a.1_data, %b_data)
-//   %8 : Dynamic = aten::mul(%a.1_mask, %b_mask)
-//   %9 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
-//   %10 : int = prim::TensorToNum(%7)
-//   %11 : Dynamic = aten::mul(%7, %8)
-//   %12 : Dynamic = aten::sum(%11)
-//   %13 : Dynamic = aten::gt[other={0}](%12)  // cond_any
-//   %14 : int = prim::TensorToNum(%13)
-//   %62 : Dynamic, %63 : Dynamic, %64 : Dynamic, %a : Dynamic, %60 : Dynamic, %61 : Dynamic = prim::Loop(%6, %14, %7, %8, %9, %a.1_data, %a.1_mask, %a.1_dims)
-//     block0(%loop_num : int, %cond_data.2 : Dynamic, %cond_mask.3 : Dynamic, %cond_dims : Dynamic, %6_data : Dynamic, %6_mask : Dynamic, %6_dims : Dynamic) {
-//       %23 : Long() = prim::Constant[value={1}]()
-//       %alpha : float = prim::TensorToNum(%23)
-//       %data.1 : Dynamic = aten::sub(%6_data, %b_data, %alpha)
-//       %mask : Dynamic = aten::mul(%6_mask, %b_mask)
-//       %dims : Dynamic = aten::__or__(%6_dims, %b_dims)
-//       %28 : Dynamic = aten::gt(%data.1, %b_data)
-//       %29 : Dynamic = aten::mul(%mask, %b_mask)
-//       %30 : Dynamic = aten::__or__(%dims, %b_dims)
-//       %31 : int = prim::TensorToNum(%28)
-//       %32 : Dynamic = aten::type_as(%cond_mask.3, %cond_data.2)  // update outputs (batch_where)
-//       %cond_mask.1 : Dynamic = aten::mul(%cond_data.2, %32)
-//       %34 : int = aten::dim(%cond_mask.1)
-//       %35 : int = prim::Constant[value=1]()
-//       %36 : int = aten::eq(%34, %35)
-//       %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%36)
-//         block0() {
-//           %40 : int = aten::dim(%data.1)
-//           %41 : int = prim::Constant[value=1]()
-//           %42 : int = aten::sub(%40, %41)
-//           %43 : int = prim::Constant[value=1]()
-//           %data.3 : Dynamic = prim::Loop(%42, %43, %cond_mask.1)
-//             block0(%_ : int, %46 : Dynamic) {
-//               %47 : int = prim::Constant[value=1]()
-//               %48 : int = aten::neg(%47)
-//               %data.2 : Dynamic = aten::unsqueeze(%46, %48)
-//               %50 : int = prim::Constant[value=1]()
-//               -> (%50, %data.2)
-//             }
-//           %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
-//           %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)
-//           -> (%cond_data.1, %cond_mask.2, %data.3)
-//         }
-//         block1() {
-//           -> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
-//         }
-//       %res_data : Dynamic = aten::where(%cond_data, %data.1, %6_data)
-//       %res_mask : Dynamic = aten::where(%cond_mask, %mask, %6_mask)
-//       %res_dims : Dynamic = aten::__or__(%dims, %6_dims)
-//       %56 : Dynamic = aten::mul(%28, %29)
-//       %57 : Dynamic = aten::sum(%56)
-//       %58 : Dynamic = aten::gt[other={0}](%57)
-//       %59 : int = prim::TensorToNum(%58)
-//       -> (%59, %28, %29, %30, %res_data, %res_mask, %res_dims)
-//     }
-//   return (%a, %60, %61);
-// }
-// clang-format on
-void ToBatch::visitLoop(Node* n, Block* block, Block* res_block) {
-  auto res_graph = res_block->owningGraph();
-  // bool cond_is_tensor indicates whether cond is tensor
-  // cond_is_tensor = false, eg: for loop, n->inputs()[1] = byte()
-  // cond_is_tensor = true, eg: in some while loop, cond is a batched tensor,
-  //                            we need to add expanded cond to the inputs of
-  //                            loop node and block, and compute cond_any as
-  //                            cond for while loop
-  bool cond_is_tensor = (batch_map.find(n->inputs()[1]) != batch_map.end());
-
-  // create prim::Loop node for res_block
-
-  // type of cond in loop should be int type
-  if (rn_env.at(n->inputs()[0])->type() != IntType::get()) {
-    rn_env[n->inputs()[0]] =
-        res_graph->insert(prim::Int, {rn_env.at(n->inputs()[0])});
-  }
-  if (cond_is_tensor) {
-    auto cond = batch_map.at(n->inputs()[1]);
-    auto cond_any = inlineUnpackedCallTo(
-        *res_block->owningGraph(), *getBatchOperator("any"), cond);
-    rn_env[n->inputs()[1]] = res_graph->insert(prim::Bool, {cond_any[0]});
-  }
-  for (size_t i = 2; i < n->inputs().size(); i++) {
-    auto input = n->inputs()[i];
-    rn_env[input] = batch_map.at(input)[0];
-  }
-  auto* r_node = res_graph->createClone(n, rn_fn, /*copy_blocks=*/false);
-
-  // change inputs of prim::Loop
-  if (cond_is_tensor) {
-    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
-      auto cond = batch_map.at(n->inputs()[1]);
-      r_node->insertInput(i + 2, cond[i]);
-    }
-  }
-  for (size_t i = 2; i < n->inputs().size(); i++) {
-    for (size_t j = 1; j < EXP_BTENSOR_SIZE; j++) {
-      r_node->insertInput(
-          (i - 2) * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE * cond_is_tensor + 2 +
-              j,
-          batch_map.at(n->inputs()[i])[j]);
-    }
-  }
-  res_block->appendNode(r_node);
-
-  // create block for Loop node in res_block
-  // if cond is tensor:    first 4 inputs of block: cond_any, cond_data,
-  //                       cond_mask, cond_dims
-  // if cond is not tensor: first 1 input of block: cond
-  auto loop_block = r_node->addBlock();
-
-  // add inputs
-  loop_block->addInput("loop_num");
-  loop_block->inputs()[0]->setType(IntType::get());
-  rn_env[n->blocks()[0]->inputs()[0]] = loop_block->inputs()[0];
-  if (cond_is_tensor) {
-    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
-      loop_block->addInput("cond_" + EXP_BTENSOR_NAME[i]);
-    }
-  }
-  for (size_t i = 1; i < n->blocks()[0]->inputs().size(); i++) {
-    auto input = n->blocks()[0]->inputs()[i];
-    auto name = input->uniqueName();
-    for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
-      loop_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]);
-    }
-    batch_map[input] =
-        std::vector<Value*>(loop_block->inputs()
-                                .slice(
-                                    (i - 1) * EXP_BTENSOR_SIZE + 1 +
-                                        EXP_BTENSOR_SIZE * cond_is_tensor,
-                                    EXP_BTENSOR_SIZE)
-                                .vec());
-  }
-
-  toBatch(n->blocks()[0], loop_block);
-
-  WithInsertPoint guard(loop_block);
-
-  // use where operator to update variables and add to outputs
-  for (size_t i = 0; i < n->outputs().size(); i++) {
-    std::vector<Value*> inputs, outputs;
-    if (cond_is_tensor) {
-      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
-        inputs.push_back(loop_block->inputs()[j + 1]);
-      }
-      auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
-      inputs.insert(inputs.end(), data.begin(), data.end());
-      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
-        inputs.push_back(
-            loop_block
-                ->inputs()[i * EXP_BTENSOR_SIZE + j + EXP_BTENSOR_SIZE + 1]);
-      }
-      outputs = inlineUnpackedCallTo(
-          *res_block->owningGraph(), *getBatchOperator("where"), inputs);
-    } else {
-      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
-        inputs.push_back(loop_block->inputs()[i * EXP_BTENSOR_SIZE + j + 1]);
-      }
-      auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
-      inputs.insert(inputs.end(), data.begin(), data.end());
-      outputs = inlineUnpackedCallTo(
-          *res_block->owningGraph(), *getBatchOperator("update"), inputs);
-    }
-    batch_map[n->outputs()[i]] = outputs;
-    for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
-      loop_block->registerOutput(outputs[j]);
-    }
-  }
-
-  // update loop conditions
-  if (cond_is_tensor) {
-    auto cond = batch_map.at(n->blocks()[0]->outputs()[0]);
-    auto cond_any = inlineUnpackedCallTo(
-        *res_block->owningGraph(), *getBatchOperator("any"), cond);
-    auto to_bool_output = res_graph->insert(prim::Bool, {cond_any[0]});
-    loop_block->insertOutput(0, to_bool_output);
-    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
-      loop_block->insertOutput(i + 1, cond[i]);
-    }
-  } else {
-    auto cond = rn_env.at(n->blocks()[0]->outputs()[0]);
-    loop_block->insertOutput(0, cond);
-  }
-
-  // change outputs of prim::Loop
-  auto size = r_node->outputs().size();
-  for (size_t i = 0; i < size; i++) {
-    for (size_t j = 1; j < EXP_BTENSOR_SIZE; j++) {
-      r_node->insertOutput(i * EXP_BTENSOR_SIZE + j);
-    }
-    batch_map[n->outputs()[i]] =
-        r_node->outputs().slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE).vec();
-  }
-  // add cond to outputs of loop node
-  if (cond_is_tensor) {
-    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
-      r_node->insertOutput(i);
-    }
-  }
-}
-
-void ToBatch::toBatch(Block* block, Block* res_block) {
-  WithInsertPoint guard(res_block);
-
-  // change inputs of block-expand tensor to batchtensor eg: (data, mask, dims)
-  // eg: a -> a_data, a_mask, a_dims for block in prim::Loop, register inputs
-  // separately to deal with cond
-  if (!block->owningNode() || block->owningNode()->kind() != prim::Loop) {
-    auto size = block->inputs().size();
-    for (size_t i = 0; i < size; i++) {
-      auto input = block->inputs()[i];
-      auto name = input->uniqueName();
-      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
-        res_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]);
-      }
-      batch_map[input] =
-          std::vector<Value*>(res_block->inputs()
-                                  .slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE)
-                                  .vec());
-    }
-  }
-
-  for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
-    auto n = *it;
-    if (n->kind().is_aten()) {
-      visitAten(n, block, res_block);
-    } else if (n->kind().is_prim()) {
-      switch (n->kind()) {
-        case prim::Constant:
-          visitConstant(n, block, res_block);
-          break;
-        case prim::NumToTensor:
-          visitNumToTensor(n, block, res_block);
-          break;
-        case prim::Bool:
-        case prim::Float:
-        case prim::Int:
-          visitTensorToNum(n, block, res_block);
-          break;
-        case prim::ListConstruct:
-          visitListConstruct(n, block, res_block);
-          break;
-        case prim::If:
-          visitIf(n, block, res_block);
-          break;
-        case prim::Loop:
-          visitLoop(n, block, res_block);
-          break;
-        default:
-          throw std::runtime_error(
-              "NYI: node of prim kind other than [Constant, NumToTensor, TensorToNum, If, Loop] is not supported yet");
-      }
-    } else {
-      throw std::runtime_error(
-          "NYI: node that is not aten or prim kind is not supported yet");
-    }
-  }
-  // change outputs of block - expand tensor to batchtensor(data, mask, dims)
-  // for block in prim::Loop, register outputs separately to deal with cond and
-  // cond_any
-  //
-  // for block in prim::If, register outputs separately by combining
-  // outputs from two paths and return
-  if (!block->owningNode() ||
-      (block->owningNode()->kind() != prim::Loop &&
-       block->owningNode()->kind() != prim::If)) {
-    for (Value* output : block->outputs()) {
-      auto r_output = batch_map.at(output);
-      for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
-        res_block->registerOutput(r_output[i]);
-      }
-    }
-  }
-}
-
-std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph) {
-  // lower the tuple before the pass
-  if (graph->outputs().at(0)->type()->kind() == TupleType::Kind) {
-    graph = graph->copy();
-    auto outs = createTupleUnpack(graph->outputs().at(0));
-    graph->eraseOutput(0);
-    for (auto o : outs)
-      graph->registerOutput(o);
-    EliminateDeadCode(graph->block());
-  }
-  std::shared_ptr<Graph> res_graph = std::make_shared<Graph>();
-  ToBatch to_batch;
-  to_batch.toBatch(graph->block(), res_graph->block());
-
-  // methods should only have a single output, so we pack everything into a
-  // tuple
-  auto tup =
-      res_graph->insertNode(res_graph->createTuple(res_graph->outputs()));
-  while (res_graph->outputs().size() > 0)
-    res_graph->eraseOutput(res_graph->outputs().size() - 1);
-  res_graph->registerOutput(tup->output());
-  EliminateDeadCode(res_graph->block());
-
-  return res_graph;
-}
-
-void initRegisterBatchOpsBindings(PyObject* module) {
-  auto m = py::handle(module).cast<py::module>();
-  m.def("to_batch_graph", to_batch_graph);
-  m.def(
-      "register_batch_operator",
-      [](std::string name, std::shared_ptr<Graph> graph) {
-        ToBatch::batch_operator_table[name].push_back(graph);
-      });
-}
-
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/passes/to_batch.h b/torch/csrc/jit/passes/to_batch.h
deleted file mode 100644 (file)
index 76bf53d..0000000
+++ /dev/null
@@ -1,47 +0,0 @@
-#pragma once
-
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/pybind.h>
-
-#include <ATen/ATen.h>
-
-namespace torch {
-namespace jit {
-
-class ToBatch {
- private:
-  // number of tensors to represent a expanded BatchTensor. {data, mask, dims}
-  // for now.
-  const size_t EXP_BTENSOR_SIZE = 3;
-  const std::vector<std::string> EXP_BTENSOR_NAME = {"data", "mask", "dims"};
-  // mapping from tensor in original graph to {data, mask, dims} in new graph
-  std::unordered_map<Value*, std::vector<Value*>> batch_map;
-  // mapping from input in original graph to new input in new graph - used in
-  // createClone
-  std::unordered_map<Value*, Value*> rn_env;
-  std::function<Value*(Value*)> rn_fn = [this](Value* v) {
-    return rn_env.at(v);
-  };
-
- private:
-  std::shared_ptr<Graph> getBatchOperator(
-      const std::string& name,
-      int64_t input_num = -1);
-  void visitAten(Node* n, Block* block, Block* res_block);
-  void visitConstant(Node* n, Block* block, Block* res_block);
-  void visitNumToTensor(Node* n, Block* block, Block* res_block);
-  void visitTensorToNum(Node* n, Block* block, Block* res_block);
-  void visitListConstruct(Node* n, Block* block, Block* res_block);
-  void visitIf(Node* n, Block* block, Block* res_block);
-  void visitLoop(Node* n, Block* block, Block* res_block);
-
- public:
-  static std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>>
-      batch_operator_table;
-  TORCH_API void toBatch(Block* block, Block* res_block);
-};
-
-TORCH_API std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph);
-TORCH_API void initRegisterBatchOpsBindings(PyObject* module);
-} // namespace jit
-} // namespace torch
index d3efa4d..f4d1c89 100644 (file)
@@ -14,7 +14,6 @@
 #include <torch/csrc/jit/hooks_for_testing.h>
 #include <torch/csrc/jit/import_source.h>
 #include <torch/csrc/jit/passes/python_print.h>
-#include <torch/csrc/jit/passes/to_batch.h>
 #include <torch/csrc/jit/pybind_utils.h>
 #include <torch/csrc/jit/python_tracer.h>
 #include <torch/csrc/jit/script/parser.h>
index 9d30e88..ec895cf 100644 (file)
@@ -57,7 +57,6 @@ _flatten = torch._C._jit_flatten
 _unflatten = torch._C._jit_unflatten
 _jit_script_compile = torch._C._jit_script_compile
 _jit_script_class_compile = torch._C._jit_script_class_compile
-BatchTensor = torch._C._jit.BatchTensor
 
 Future = torch._C.Future
 _fork = torch._C.fork
@@ -800,38 +799,6 @@ def _is_weak_type(cls):
     return cls in _jit_internal.weak_types
 
 
-def batch(batch_size=1, optimize=True, _frames_up=0):
-    def decorator(fn):
-        if not _enabled:
-            return fn
-        import torch.jit.batchop
-        mod = script(fn, optimize, _frames_up)
-        res_graph = torch.to_batch_graph(mod.graph)
-        res_mod = ScriptModule()
-        res_mod._create_method_from_graph('forward', res_graph)
-
-        def wrapper(*args):
-            new_args = []
-            for arg in args:
-                if isinstance(arg, torch.Tensor):
-                    arg = BatchTensor(arg, batch_size)
-                if isinstance(arg, BatchTensor):
-                    new_args.extend([arg.get_data(), arg.get_mask(), arg.get_dims()])
-                else:
-                    new_args.append(arg)
-            res = res_mod(*new_args)
-            assert len(res) % 3 == 0
-            if len(res) % 3 != 0:
-                raise "non-batched-tensor output is not supported yet"
-            result = [BatchTensor(*res[i * 3: i * 3 + 3]) for i in range(len(res) // 3)]
-            if len(result) == 1:
-                return result[0]
-            return result
-        wrapper.__doc__ = fn.__doc__
-        return wrapper
-    return decorator
-
-
 # These OrderedDictWrapper classes replace the actual OrderedDicts in
 # module with versions that get/set properties inside of script::Module.
 # This allows us to reuse most of nn.Module while still storing the
diff --git a/torch/jit/batchop.py b/torch/jit/batchop.py
deleted file mode 100644 (file)
index b0af553..0000000
+++ /dev/null
@@ -1,528 +0,0 @@
-import torch
-from torch.jit import BatchTensor
-
-
-# TODO: there are some commented raise statements
-# when we support rasie exception in script, we want to check them
-@torch.jit.script
-def batch_tanh(data, mask, dims):
-    data = torch.tanh(data)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_sigmoid(data, mask, dims):
-    data = torch.sigmoid(data)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_relu(data, mask, dims):
-    data = torch.relu(data)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_neg(data, mask, dims):
-    data = torch.neg(data)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_neg_scalar(data):
-    return torch.neg(data)
-
-
-@torch.jit.script
-def batch_add(data1, mask1, dims1, data2, mask2, dims2, alpha_):
-    alpha = float(alpha_)
-    data = torch.add(data1, data2, alpha=alpha)
-    mask = mask1 * mask2
-    dims = dims1.__or__(dims2)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_add_scalar(data, mask, dims, other, alpha_):
-    alpha = float(alpha_)
-    data = torch.add(data, other.type_as(data), alpha=alpha)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_sub(data1, mask1, dims1, data2, mask2, dims2, alpha_):
-    alpha = float(alpha_)
-    data = torch.sub(data1, data2, alpha=alpha)
-    mask = mask1 * mask2
-    dims = dims1.__or__(dims2)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_sub_scalar(data1, data2):
-    return data1 - data2
-
-
-@torch.jit.script
-def batch_mul(data1, mask1, dims1, data2, mask2, dims2):
-    data = torch.mul(data1, data2)
-    mask = mask1 * mask2
-    dims = dims1.__or__(dims2)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_mul_scalar(data1, data2):
-    return data1 * data2
-
-
-@torch.jit.script
-def batch_div(data, mask, dims, other):  # div(batchtensor, scalar)
-    data = torch.div(data, other)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_mm(data1, mask1, dims1, data2, mask2, dims2):
-    data1 = data1 * mask1.type_as(data1)
-    data2 = data2 * mask2.type_as(data2)
-    data = torch.bmm(data1, data2)
-    mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1))
-    dims = torch.cat((dims1[:1], dims2[1:dims2.size(0)]))
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_matmul(data1, mask1, dims1, data2, mask2, dims2):
-    d1 = data1.dim() - 1
-    d2 = data2.dim() - 1
-    data1 = data1 * mask1.type_as(data1)
-    data2 = data2 * mask2.type_as(data2)
-    if d1 == 1:
-        data1 = data1.unsqueeze(-2)
-    if d2 == 1:
-        data2 = data2.unsqueeze(-1)
-    data = torch.bmm(data1, data2)
-    mask = mask1
-    dims = dims1
-    if d1 == 1 and d2 == 1:
-        # if (batch1.dims[0] or batch2.dims[0]) and not batch1.mask.eq(batch2.mask).all():
-        #    raise ValueError("cannot contract non-matching dimensions")
-        data = data.squeeze(-1).squeeze(-1)
-        mask = mask1.narrow(1, 0, 1).squeeze(-1)
-        dims = dims1[:0]  # empty tensor
-    if d1 == 2 and d2 == 1:
-        # if (batch1.dims[1] or batch2.dims[0]) and not batch1.mask[:, 0].eq(batch2.mask).all():
-        #    raise ValueError("cannot contract non-matching dimensions")
-        data = data.squeeze(-1)
-        mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1).unsqueeze(-1)).squeeze(-1)
-        dims = dims1[:1]
-    elif d1 == 1 and d2 == 2:
-        # if (batch1.dims[0] or batch2.dims[0]) and not batch1.mask.eq(batch2.mask[:, :, 0]).all():
-        #    raise ValueError("cannot contract non-matching dimensions")
-        data = data.squeeze(-2)
-        mask = torch.bmm(mask1.narrow(1, 0, 1).unsqueeze(-2), mask2.narrow(1, 0, 1)).squeeze(-2)
-        dims = dims2[1:dims2.size(0)]
-    elif d1 == 2 and d2 == 2:
-        # if (batch1.dims[1] or batch2.dims[0]) and not batch1.mask[:, 0].eq(batch2.mask[:, :, 0]).all():
-        #    raise ValueError("cannot contract non-matching dimensions")
-        mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1))
-        dims = torch.cat((dims1[:1], dims2[1:dims2.size(0)]))
-    # else:
-    #     raise NotImplementedError("matmul not implemented with batches of 3+D tensors")
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_select(data, mask, dims, dim_, index_):
-    dim = int(dim_)
-    index = int(index_)
-    # if dim == 0:
-    #     raise ValueError("Cannot select 0 dim in BatchTensor")
-    data = data.select(dim, index)
-    if bool(dims[dim - 1]):
-        mask = mask.select(dim, index)
-    else:
-        mask = mask.select(dim, 0)
-    dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_fmod(data, mask, dims, other_):
-    other = int(other_)
-    data = torch.fmod(data, other)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_zeros_like(data, mask, dims):
-    res_data = torch.zeros_like(data)
-    return res_data, mask, dims
-
-
-@torch.jit.script
-def batch_index_select(data, mask, dims, dim_, index_data, index_mask, index_dims):
-    dim = int(dim_)
-    # if dim == 0:
-    #     raise ValueError("Cannot index_select along 0 dim in BatchTensor")
-    batch_size = data.size(0)  # TODO maybe index_mask will be used at some point
-    res_data = torch.zeros([0])
-    res_mask = torch.zeros([0])
-    for i in range(batch_size):
-        d = data[i].index_select(dim - 1, index_data[i]).unsqueeze(0)
-        if bool(dims[dim - 1]):
-            m = mask[i].index_select(dim - 1, index_data[i]).unsqueeze(0)
-        else:
-            m = mask[i].unsqueeze(0)
-        if i == 0:
-            res_data = d
-            res_mask = m
-        else:
-            res_data = torch.cat((res_data, d), 0)
-            res_mask = torch.cat((res_mask, m), 0)
-    return res_data, res_mask, dims
-
-
-@torch.jit.script
-def batch_view_as(data, mask, dims, data1, mask1, dims1):
-    # if data.size(0) != data1.size(0):
-    #     raise ValueError("In view_as, tensor and target tensor should have the same batch_size")
-    # if not torch.equal(dims, dims1):
-    #     raise ValueError("In batched view_as, dims and target dims should be the same")
-    data = data.view_as(data1)
-    mask = mask.view_as(mask1)
-    dims = dims1
-    return data, mask, dims
-
-
-# assume data, data1, data2 have same size
-@torch.jit.script
-def batch_where(data, mask, dims, data1, mask1, dims1, data2, mask2, dims2):
-    data = data * mask.type_as(data)
-    cond_data = data
-    cond_mask = data
-    if data.dim() == 1:
-        for _ in range(data1.dim() - 1):
-            data = data.unsqueeze(data.dim())
-        cond_data = data.expand_as(data1)
-        cond_mask = data.expand_as(mask1)
-    res_data = torch.where(cond_data, data1, data2)
-    res_mask = torch.where(cond_mask, mask1, mask2)
-    res_dims = dims1.__or__(dims2)
-    return res_data, res_mask, res_dims
-
-
-@torch.jit.script
-def batch_where_scalar(cond, data1, mask1, dims1, data2, mask2, dims2):
-    cond = torch.zeros([1], dtype=torch.uint8)
-    res_data = torch.where(cond, data1, data2)
-    res_mask = torch.where(cond, mask1, mask2)
-    res_dims = torch.where(cond, dims1, dims2)
-    return res_data, res_mask, res_dims
-
-
-@torch.jit.script
-def batch_update(batch_data, batch_mask, batch_dims, new_data, new_mask, new_dims):
-    data = torch.where(new_mask, new_data, batch_data)
-    return data, new_mask, new_dims  # TODO: consider whether return new_mask and new_dims
-
-
-@torch.jit.script
-def batch_any(data, mask, dims):
-    return torch.gt(torch.sum(data * mask), 0)
-
-
-@torch.jit.script
-def batch_type_as(data, mask, dims, data1, mask1, dims1):
-    return data.type_as(data1), mask, dims
-
-
-@torch.jit.script
-def batch_gt(data, mask, dims, data1, mask1, dims1):
-    return torch.gt(data, data1), mask * mask1, dims.__or__(dims1)
-
-
-@torch.jit.script
-def batch_gt_scalar(data1, data2):
-    return torch.gt(data1, data2)
-
-
-@torch.jit.script
-def batch_gt_one_scalar(data, mask, dims, other_):
-    other = float(other_)
-    return torch.gt(data, other), mask, dims
-
-
-@torch.jit.script
-def batch_lt(data, mask, dims, data1, mask1, dims1):
-    return torch.lt(data, data1), mask * mask1, dims.__or__(dims1)
-
-
-@torch.jit.script
-def batch_eq(data, mask, dims, data1, mask1, dims1):
-    return torch.eq(data, data1), mask * mask1, dims.__or__(dims1)
-
-
-@torch.jit.script
-def batch_size(data, mask, dims, dim_):
-    dim = int(dim_)
-    return data.size(dim)
-
-
-@torch.jit.script
-def batch_dim(data, mask, dims):
-    return data.dim()
-
-
-@torch.jit.script
-def batch_squeeze(data, mask, dims, dim_):
-    if int(dim_) < 0:
-        dim_ = dim_ + data.dim()
-    dim = int(dim_)
-    # if dim == 0:
-    #     raise ValueError("cannot do squeeze along batch_dim")
-    data = data.squeeze(dim)
-    mask = mask.squeeze(dim)
-    dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_unsqueeze(data, mask, dims, dim_):
-    if int(dim_) < 0:
-        dim_ = dim_ + data.dim() + 1
-    dim = int(dim_)
-    # if dim == 0:
-    #     raise ValueError("cannot do unsqueeze along batch_dim")
-    data = data.unsqueeze(dim)
-    mask = mask.unsqueeze(dim)
-    dims = torch.cat((dims[:dim], torch.zeros([1], dtype=torch.uint8), dims[dim:dims.size(0)]))
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_argmax(data, mask, dims, dim_, keepdim_):
-    dim = int(dim_)
-    keepdim = bool(keepdim_)
-    # if dim == 0:
-    #     raise ValueError("cannot do argmax along batch_dim")
-    batch_size = data.size(0)
-    res_data = torch.zeros([0])
-    for i in range(batch_size):
-        if bool(dims[dim - 1]):
-            if dim - 1 != 0:
-                m = mask[i].transpose(0, dim - 1)
-            else:
-                m = mask[i]
-            valid_num = m.sum(0, keepdim=True)
-            while(valid_num.dim() >= 1):
-                valid_num = valid_num[0]
-            d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
-        else:
-            d = data[i].unsqueeze(0)
-        d = d.argmax(dim, keepdim)
-        if i == 0:
-            res_data = d
-        else:
-            res_data = torch.cat([res_data, d], 0)
-    if keepdim:
-        mask = mask
-    else:
-        mask = mask.select(dim, 0)
-        dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
-    return res_data, mask, dims
-
-
-@torch.jit.script
-def batch_topk(data, mask, dims, k_, dim_, largest_, sorted_):
-    k = int(k_)
-    dim = int(dim_)
-    largest = bool(largest_)
-    sorted = bool(sorted_)
-    # if dim == 0:
-    #     raise ValueError("cannot do topk along batch_dim")
-    batch_size = data.size(0)
-    res_data = torch.zeros([0])
-    res_index = torch.zeros([0])
-    for i in range(batch_size):
-        if bool(dims[dim - 1]):
-            if dim - 1 != 0:
-                m = mask[i].transpose(0, dim - 1)
-            else:
-                m = mask[i]
-            valid_num = m.sum(0, keepdim=True)
-            while(valid_num.dim() >= 1):
-                valid_num = valid_num[0]
-            d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
-        else:
-            d = data[i].unsqueeze(0)
-        d, idx = d.topk(k, dim, largest, sorted)
-        if i == 0:
-            res_data = d
-            res_index = idx
-        else:
-            res_data = torch.cat([res_data, d], 0)
-            res_index = torch.cat([res_index, idx], 0)
-    if bool(dims[dim - 1]):
-        mask = mask.narrow(dim, 0, k)
-    return res_data, mask, dims, res_index, mask, dims
-
-
-@torch.jit.script
-def batch_softmax(data, mask, dims, dim_):
-    dim = int(dim_)
-    # if dim == 0:
-    #     raise ValueError("cannot do softmax along batch_dim")
-    batch_size = data.size(0)
-    max_len = data.size(dim)
-    res_data = torch.zeros([0])
-    for i in range(batch_size):
-        if bool(dims[dim - 1]):
-            if dim - 1 != 0:
-                m = mask[i].transpose(0, dim - 1)
-            else:
-                m = mask[i]
-            valid_num = m.sum(0, keepdim=True)
-            while(valid_num.dim() >= 1):
-                valid_num = valid_num[0]
-            valid_num = int(valid_num)
-            d = data[i].unsqueeze(0).narrow(dim, 0, valid_num).softmax(dim)
-            if valid_num < max_len:
-                d = torch.cat([d, data[i].unsqueeze(0).narrow(dim, valid_num, max_len - valid_num)], dim)
-        else:
-            d = data[i].unsqueeze(0).softmax(dim)
-        if i == 0:
-            res_data = d
-        else:
-            res_data = torch.cat([res_data, d], 0)
-    return res_data, mask, dims
-
-
-# size argument in dynamic dimension has to be -1
-# in static dimension, size has to be specified, -1 is not supported
-@torch.jit.script
-def batch_view(data, mask, dims, sizes):
-    batch_size = data.size(0)
-    # if(sizes[0] != batch_size and sizes[0] != -1 and sizes[0] != 1):
-    #     raise "first dim in view must be 1, -1, or batch size"
-    # for i in range(dims.size(0)):
-    #     if dims[0] == 1 and sizes[i + 1] != -1:
-    #         raise "size argument in dynamic dimension has to be -1"
-    sizes = sizes.type_as(torch.ones([1], dtype=torch.int))
-    data_sizes_ = torch.cat([torch.ones([1], dtype=torch.int) * batch_size, sizes.narrow(0, 1, sizes.size(0) - 1)], 0)
-    data_sizes = data_sizes_._tensor_to_list()
-    res_data = data.view(data_sizes)
-    mask_sizes_ = data_sizes_.narrow(0, 0, 1)
-    res_dims = data_sizes_.narrow(0, 0, 1)
-    for i_ in range(sizes.size(0) - 1):
-        i = i_ + 1
-        if bool(sizes[i] == -1):
-            cur_size_ = mask.size(i)
-            cur_dim = 1
-        else:
-            cur_size_ = 1
-            cur_dim = 0
-        mask_sizes_ = torch.cat([mask_sizes_, torch.ones([1], dtype=torch.int) * cur_size_])
-        res_dims = torch.cat([res_dims, torch.ones([1], dtype=torch.int) * cur_dim])
-    mask_sizes = mask_sizes_._tensor_to_list()
-    res_mask = mask.view(mask_sizes)
-    return res_data, res_mask, res_dims.narrow(0, 1, res_dims.size(0) - 1).type_as(dims)
-
-
-@torch.jit.script
-def batch_cat2(data1, mask1, dims1, data2, mask2, dims2, dim_):
-    dim = int(dim_)
-    data = torch.cat([data1, data2], dim)
-    if bool(dims1[dim - 1]):
-        mask = torch.cat([mask1, mask2], dim)
-    else:
-        mask = mask1
-    return data, mask, dims1
-
-
-@torch.jit.script
-def batch_cat3(data1, mask1, dims1, data2, mask2, dims2, data3, mask3, dims3, dim_):
-    dim = int(dim_)
-    data = torch.cat([data1, data2, data3], dim)
-    if bool(dims1[dim - 1]):
-        mask = torch.cat([mask1, mask2, mask3], dim)
-    else:
-        mask = mask1
-    return data, mask, dims1
-
-
-@torch.jit.script
-def batch_narrow(data, mask, dims, dimension_, start_, length_):
-    dimension = int(dimension_)
-    start = int(start_)
-    length = int(length_)
-    # if dimension == 0:
-    #     raise ValueError("cannot do narrow along batch_dim")
-    data = data.narrow(dimension, start, length)
-    if bool(dims[dimension - 1]):
-        mask = mask.narrow(dimension, start, length)
-    else:
-        mask = mask.narrow(dimension, 0, 1)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_sum(data, mask, dims):
-    data = data * mask.type_as(data)
-    for _ in range(dims.size(0)):
-        data = data.sum(1)
-    mask = torch.ones([data.size(0)], dtype=torch.uint8)
-    dims = dims[:0]  # empty tensor
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_from_scalar_tensor(data):
-    data = data.unsqueeze(0)
-    mask = torch.ones([1], dtype=torch.uint8)
-    dims = torch.zeros([0], dtype=torch.uint8)
-    return data, mask, dims
-
-torch.register_batch_operator("tanh", batch_tanh.graph)
-torch.register_batch_operator("sigmoid", batch_sigmoid.graph)
-torch.register_batch_operator("relu", batch_relu.graph)
-torch.register_batch_operator("neg", batch_neg.graph)
-torch.register_batch_operator("neg", batch_neg_scalar.graph)
-torch.register_batch_operator("add", batch_add.graph)
-torch.register_batch_operator("add", batch_add_scalar.graph)
-torch.register_batch_operator("sub", batch_sub.graph)
-torch.register_batch_operator("sub", batch_sub_scalar.graph)
-torch.register_batch_operator("mul", batch_mul.graph)
-torch.register_batch_operator("mul", batch_mul_scalar.graph)
-torch.register_batch_operator("div", batch_div.graph)
-torch.register_batch_operator("matmul", batch_matmul.graph)
-torch.register_batch_operator("mm", batch_mm.graph)
-torch.register_batch_operator("fmod", batch_fmod.graph)
-torch.register_batch_operator("zeros_like", batch_zeros_like.graph)
-torch.register_batch_operator("select", batch_select.graph)
-torch.register_batch_operator("index_select", batch_index_select.graph)
-torch.register_batch_operator("view_as", batch_view_as.graph)
-torch.register_batch_operator("where", batch_where.graph)
-torch.register_batch_operator("where", batch_where_scalar.graph)
-torch.register_batch_operator("update", batch_update.graph)
-torch.register_batch_operator("any", batch_any.graph)
-torch.register_batch_operator("type_as", batch_type_as.graph)
-torch.register_batch_operator("gt", batch_gt.graph)
-torch.register_batch_operator("gt", batch_gt_scalar.graph)
-torch.register_batch_operator("gt", batch_gt_one_scalar.graph)
-torch.register_batch_operator("lt", batch_lt.graph)
-torch.register_batch_operator("eq", batch_eq.graph)
-torch.register_batch_operator("size", batch_size.graph)
-torch.register_batch_operator("dim", batch_dim.graph)
-torch.register_batch_operator("squeeze", batch_squeeze.graph)
-torch.register_batch_operator("unsqueeze", batch_unsqueeze.graph)
-torch.register_batch_operator("argmax", batch_argmax.graph)
-torch.register_batch_operator("topk", batch_topk.graph)
-torch.register_batch_operator("softmax", batch_softmax.graph)
-torch.register_batch_operator("view", batch_view.graph)
-torch.register_batch_operator("cat", batch_cat2.graph)
-torch.register_batch_operator("cat", batch_cat3.graph)
-torch.register_batch_operator("narrow", batch_narrow.graph)
-torch.register_batch_operator("sum", batch_sum.graph)
-torch.register_batch_operator("batch_from_scalar_tensor", batch_from_scalar_tensor.graph)