constant_op.constant(-2))
self.assertListEqual([0, 1], sess.run(x).tolist())
+ def test_converted_call_builtin(self):
+ x = api.converted_call(range, False, False, {}, 3)
+ self.assertEqual((0, 1, 2), tuple(x))
+
+ def test_converted_call_function(self):
+
+ def test_fn(x):
+ if x < 0:
+ return -x
+ return x
+
+ with self.test_session() as sess:
+ x = api.converted_call(
+ test_fn, False, False, {}, constant_op.constant(-1))
+ self.assertEqual(1, sess.run(x))
+
+ def test_converted_call_method(self):
+
+ class TestClass(object):
+
+ def __init__(self, x):
+ self.x = x
+
+ def test_method(self):
+ if self.x < 0:
+ return -self.x
+ return self.x
+
+ with self.test_session() as sess:
+ tc = TestClass(constant_op.constant(-1))
+ x = api.converted_call(tc.test_method, False, False, {}, tc)
+ self.assertEqual(1, sess.run(x))
+
+ def test_converted_call_method_by_class(self):
+
+ class TestClass(object):
+
+ def __init__(self, x):
+ self.x = x
+
+ def test_method(self):
+ if self.x < 0:
+ return -self.x
+ return self.x
+
+ with self.test_session() as sess:
+ tc = TestClass(constant_op.constant(-1))
+ x = api.converted_call(TestClass.test_method, False, False, {}, tc)
+ self.assertEqual(1, sess.run(x))
+
+ def test_converted_call_callable_object(self):
+
+ class TestClass(object):
+
+ def __init__(self, x):
+ self.x = x
+
+ def __call__(self):
+ if self.x < 0:
+ return -self.x
+ return self.x
+
+ with self.test_session() as sess:
+ tc = TestClass(constant_op.constant(-1))
+ x = api.converted_call(tc, False, False, {})
+ self.assertEqual(1, sess.run(x))
+
+ def test_converted_call_constructor(self):
+
+ class TestClass(object):
+
+ def __init__(self, x):
+ self.x = x
+
+ def test_method(self):
+ if self.x < 0:
+ return -self.x
+ return self.x
+
+ with self.test_session() as sess:
+ tc = api.converted_call(
+ TestClass, False, False, {}, constant_op.constant(-1))
+ # tc is now a converted object.
+ x = tc.test_method()
+ self.assertEqual(1, sess.run(x))
+
def test_to_graph_basic(self):
def test_fn(x, s):