From bf741007d1f6f440a2671b9fa8894af3df10ed44 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 20 Mar 2018 21:30:02 -0700 Subject: [PATCH] C API: fix device + colocation edge case in import_graph_def This change makes the C API consistent with the Python API, by making sure that all nodes in a colocation group have the device of the op named in the "_class" attr (all other ops' devices are ignored). This is currently done by preserving the current Python logic for colocation and devices, which only works if all ops start with no device set. Without this change, imported nodes would have the device specified in the GraphDef. This change unsets any device before running the Python logic. PiperOrigin-RevId: 189859688 --- tensorflow/python/framework/importer.py | 11 ++++--- tensorflow/python/framework/importer_test.py | 43 ++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index a9e399f..4ea34d7 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -301,14 +301,17 @@ def _ProcessNewOps(graph): colocation_pairs = {} for new_op in graph._add_new_tf_operations(compute_devices=False): # pylint: disable=protected-access + original_device = new_op.device + new_op._set_device('') # pylint: disable=protected-access colocation_names = _GetColocationNames(new_op) if colocation_names: colocation_pairs[new_op] = colocation_names - # Don't apply this op's device function, since colocation constraints - # override device functions. Note that this op's device may still be set - # by the loop below. + # Don't set a device for this op, since colocation constraints override + # device functions and the original device. Note that this op's device may + # still be set by the loop below. + # TODO(skyewm): why does it override the original device? else: - with _MaybeDevice(new_op.device): + with _MaybeDevice(original_device): graph._apply_device_functions(new_op) # pylint: disable=protected-access # The following loop populates the device field of ops that are colocated diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index bf5d9fe..6593b17 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -680,6 +680,49 @@ class ImportGraphDefTest(test.TestCase): "list { s: 'loc:@imported_graph/A' }", b.node_def.attr["_class"]) + def testColocationAndDevice(self): + # A and B are colocated, device set on A. + original_graph_def = self._MakeGraphDef(""" + node { name: 'A' op: 'None' device: '/device:CPU:0' attr { + key: '_class' + value { list { s: 'loc:@A' } } + } } + node { name: 'B' op: 'None' attr { + key: '_class' + value { list { s: 'loc:@A' } } + } }""") + + with ops.Graph().as_default(): + a, b = importer.import_graph_def(original_graph_def, + return_elements=["A", "B"], + name="") + self.assertEqual(a.device, "/device:CPU:0") + self.assertEqual(b.device, "/device:CPU:0") + self.assertEqual(a.colocation_groups(), [b"loc:@A"]) + self.assertEqual(b.colocation_groups(), [b"loc:@A"]) + + # A and B are colocated, device set on B. + original_graph_def = self._MakeGraphDef(""" + node { name: 'A' op: 'None' attr { + key: '_class' + value { list { s: 'loc:@A' } } + } } + node { name: 'B' op: 'None' device: '/device:CPU:0' attr { + key: '_class' + value { list { s: 'loc:@A' } } + } }""") + + with ops.Graph().as_default(): + a, b = importer.import_graph_def(original_graph_def, + return_elements=["A", "B"], + name="") + # TODO(skyewm): this behavior seems inconsistent with the above. Why is + # B's device ignored? + self.assertEqual(a.device, "") + self.assertEqual(b.device, "") + self.assertEqual(a.colocation_groups(), [b"loc:@A"]) + self.assertEqual(b.colocation_groups(), [b"loc:@A"]) + def testColocationWithDeviceFn(self): original_graph_def = self._MakeGraphDef(""" node { name: 'A' op: 'None' attr { -- 2.7.4