Fix bug in converted_call, and add tests for it.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 13 Apr 2018 11:44:10 +0000 (04:44 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 13 Apr 2018 11:46:40 +0000 (04:46 -0700)
PiperOrigin-RevId: 192751211

tensorflow/contrib/autograph/impl/api.py
tensorflow/contrib/autograph/impl/api_test.py

index a553813..a00d9c6 100644 (file)
@@ -156,7 +156,7 @@ def converted_call(f, recursive, verbose, arg_types, *args, **kwargs):
     # Constructors
     target_entity = f
     arg_map_target = f.__init__
-    effective_args = (unknown_arg_value,) + args
+    effective_args = args
     partial_types = ()
 
   elif hasattr(f, '__call__') and hasattr(f, '__class__'):
index f9db077..2e09d19 100644 (file)
@@ -179,6 +179,92 @@ class ApiTest(test.TestCase):
           constant_op.constant(-2))
       self.assertListEqual([0, 1], sess.run(x).tolist())
 
+  def test_converted_call_builtin(self):
+    x = api.converted_call(range, False, False, {}, 3)
+    self.assertEqual((0, 1, 2), tuple(x))
+
+  def test_converted_call_function(self):
+
+    def test_fn(x):
+      if x < 0:
+        return -x
+      return x
+
+    with self.test_session() as sess:
+      x = api.converted_call(
+          test_fn, False, False, {}, constant_op.constant(-1))
+      self.assertEqual(1, sess.run(x))
+
+  def test_converted_call_method(self):
+
+    class TestClass(object):
+
+      def __init__(self, x):
+        self.x = x
+
+      def test_method(self):
+        if self.x < 0:
+          return -self.x
+        return self.x
+
+    with self.test_session() as sess:
+      tc = TestClass(constant_op.constant(-1))
+      x = api.converted_call(tc.test_method, False, False, {}, tc)
+      self.assertEqual(1, sess.run(x))
+
+  def test_converted_call_method_by_class(self):
+
+    class TestClass(object):
+
+      def __init__(self, x):
+        self.x = x
+
+      def test_method(self):
+        if self.x < 0:
+          return -self.x
+        return self.x
+
+    with self.test_session() as sess:
+      tc = TestClass(constant_op.constant(-1))
+      x = api.converted_call(TestClass.test_method, False, False, {}, tc)
+      self.assertEqual(1, sess.run(x))
+
+  def test_converted_call_callable_object(self):
+
+    class TestClass(object):
+
+      def __init__(self, x):
+        self.x = x
+
+      def __call__(self):
+        if self.x < 0:
+          return -self.x
+        return self.x
+
+    with self.test_session() as sess:
+      tc = TestClass(constant_op.constant(-1))
+      x = api.converted_call(tc, False, False, {})
+      self.assertEqual(1, sess.run(x))
+
+  def test_converted_call_constructor(self):
+
+    class TestClass(object):
+
+      def __init__(self, x):
+        self.x = x
+
+      def test_method(self):
+        if self.x < 0:
+          return -self.x
+        return self.x
+
+    with self.test_session() as sess:
+      tc = api.converted_call(
+          TestClass, False, False, {}, constant_op.constant(-1))
+      # tc is now a converted object.
+      x = tc.test_method()
+      self.assertEqual(1, sess.run(x))
+
   def test_to_graph_basic(self):
 
     def test_fn(x, s):