from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
+from tensorflow.python.framework.errors import InvalidArgumentError
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
for node in divide_fdef[0].node_def:
self.assertAllEqual(node.device, "/device:CPU:1")
+ def _testNestedDeviceWithSameFunction(self, func_name):
+
+ def MatmulWrap(a, b):
+
+ @function.Defun(
+ func_name=func_name, *[dtypes.int32] * 2)
+ def Matmul(a, b):
+ return math_ops.matmul(a, b)
+
+ return Matmul(a, b)
+
+ with ops.Graph().as_default(), ops.device("CPU:0"):
+ c = MatmulWrap(1, 2)
+
+ with ops.device("CPU:1"):
+ MatmulWrap(c, 3)
+
+ gdef = ops.get_default_graph().as_graph_def()
+
+ devices = []
+ for node in gdef.library.function[0].node_def:
+ devices.append(node.device)
+ for node in gdef.library.function[1].node_def:
+ devices.append(node.device)
+
+ self.assertAllEqual(sorted(devices), ["/device:CPU:0", "/device:CPU:1"])
+
+ def testFunctionWithName(self):
+ with self.assertRaises(InvalidArgumentError) as cm:
+ self._testNestedDeviceWithSameFunction("MatmulTest")
+ self.assertEqual(
+ cm.exception.message,
+ "Cannot add function \'MatmulTest\' because a different "
+ "function with the same name already exists.")
+
+ def testFunctionWithoutName(self):
+ self._testNestedDeviceWithSameFunction(None)
+
if __name__ == "__main__":
test.main()