w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch, idx_batch)
self.assertEqual(ys, ybs.examples())
+def execWrapper(code, glob, loc):
+ if PY2:
+ exec(code) in glob, loc
+ else:
+ exec(code, glob, loc)
class TestScript(JitTestCase):
@contextmanager
def run_test(code):
scope = {}
- exec(code, globals(), scope)
+ execWrapper(code, globals(), scope)
cu = torch.jit.CompilationUnit(code)
self.assertEqual(cu.func(), scope['func']())
code = template.format(lhs=args[0], rhs=args[1], op=op)
scope = {}
- exec(code, globals(), scope)
+ execWrapper(code, globals(), scope)
cu = torch.jit.CompilationUnit(code)
self.assertEqual(cu.func(tensor), scope['func'](tensor))
def test(op, args):
code = template.format(lhs=args[0], rhs=args[1], op=op)
scope = {}
- exec(code, globals(), scope)
+ execWrapper(code, globals(), scope)
cu = torch.jit.CompilationUnit(code)
self.assertEqual(
cu.func(),
def test(inp, typ, type_hint):
code = template.format(typ=typ, type_hint=type_hint)
scope = {}
- exec(code, globals(), scope)
+ execWrapper(code, globals(), scope)
cu = torch.jit.CompilationUnit(code)
self.assertEqual(
cu.func(inp),