A couple of small device-related utilities.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 24 Mar 2018 03:18:46 +0000 (20:18 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sun, 25 Mar 2018 11:54:55 +0000 (04:54 -0700)
PiperOrigin-RevId: 190312148

tensorflow/python/BUILD
tensorflow/python/training/device_util.py [new file with mode: 0644]

index acfdcd1..e6ad564 100644 (file)
@@ -2884,9 +2884,11 @@ py_library(
         ":client",
         ":control_flow_ops",
         ":data_flow_ops",
+        ":device",
         ":errors",
         ":framework",
         ":framework_for_generated_wrappers",
+        ":framework_ops",
         ":gradients",
         ":init_ops",
         ":io_ops",
@@ -2911,6 +2913,7 @@ py_library(
         ":variable_scope",
         ":variables",
         "//tensorflow/python/eager:backprop",
+        "//tensorflow/python/eager:context",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
diff --git a/tensorflow/python/training/device_util.py b/tensorflow/python/training/device_util.py
new file mode 100644 (file)
index 0000000..f1137e8
--- /dev/null
@@ -0,0 +1,68 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Device-related support functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import device as tf_device
+from tensorflow.python.framework import ops
+
+
+def canonicalize(d):
+  d = tf_device.DeviceSpec.from_string(d)
+  assert d.device_type is None or d.device_type == d.device_type.upper(), (
+      "Device type '%s' must be all-caps." % (d.device_type,))
+  # Fill in missing device fields using defaults.
+  result = tf_device.DeviceSpec(
+      job="localhost", replica=0, task=0, device_type="CPU", device_index=0)
+  result.merge_from(d)
+  return result.to_string()
+
+
+class _FakeNodeDef(object):
+  """A fake NodeDef for _FakeOperation."""
+
+  def __init__(self):
+    self.op = ""
+    self.name = ""
+
+
+class _FakeOperation(object):
+  """A fake Operation object to pass to device functions."""
+
+  def __init__(self):
+    self.device = ""
+    self.type = ""
+    self.name = ""
+    self.node_def = _FakeNodeDef()
+
+  def _set_device(self, device):
+    self.device = ops._device_string(device)  # pylint: disable=protected-access
+
+
+def current():
+  """Return a string (not canonicalized) for the current device."""
+  # TODO(josh11b): Work out how this function interacts with ops.colocate_with.
+  ctx = context.context()
+  if ctx.executing_eagerly():
+    d = ctx.device_name
+  else:
+    op = _FakeOperation()
+    ops.get_default_graph()._apply_device_functions(op)  # pylint: disable=protected-access
+    d = op.device
+  return d