Add more test cases in function_test
authorYoulong Cheng <ylc@google.com>
Thu, 17 May 2018 22:47:24 +0000 (15:47 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 22:49:58 +0000 (15:49 -0700)
PiperOrigin-RevId: 197064629

tensorflow/python/framework/function_test.py

index 88f6a36..15e41ba 100644 (file)
@@ -36,6 +36,7 @@ from tensorflow.python.framework import graph_to_function_def
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
+from tensorflow.python.framework.errors import InvalidArgumentError
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
@@ -1764,6 +1765,44 @@ class DevicePlacementTest(test.TestCase):
       for node in divide_fdef[0].node_def:
         self.assertAllEqual(node.device, "/device:CPU:1")
 
+  def _testNestedDeviceWithSameFunction(self, func_name):
+
+    def MatmulWrap(a, b):
+
+      @function.Defun(
+          func_name=func_name, *[dtypes.int32] * 2)
+      def Matmul(a, b):
+        return math_ops.matmul(a, b)
+
+      return Matmul(a, b)
+
+    with ops.Graph().as_default(), ops.device("CPU:0"):
+      c = MatmulWrap(1, 2)
+
+      with ops.device("CPU:1"):
+        MatmulWrap(c, 3)
+
+      gdef = ops.get_default_graph().as_graph_def()
+
+      devices = []
+      for node in gdef.library.function[0].node_def:
+        devices.append(node.device)
+      for node in gdef.library.function[1].node_def:
+        devices.append(node.device)
+
+      self.assertAllEqual(sorted(devices), ["/device:CPU:0", "/device:CPU:1"])
+
+  def testFunctionWithName(self):
+    with self.assertRaises(InvalidArgumentError) as cm:
+      self._testNestedDeviceWithSameFunction("MatmulTest")
+    self.assertEqual(
+        cm.exception.message,
+        "Cannot add function \'MatmulTest\' because a different "
+        "function with the same name already exists.")
+
+  def testFunctionWithoutName(self):
+    self._testNestedDeviceWithSameFunction(None)
+
 
 if __name__ == "__main__":
   test.main()