From: A. Unique TensorFlower Date: Sat, 24 Mar 2018 03:18:46 +0000 (-0700) Subject: A couple of small device-related utilities. X-Git-Tag: tflite-v0.1.7~106^2^2~13 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=418ae5ed77f1353c794f93a4adfbf7db02fa3191;p=platform%2Fupstream%2Ftensorflow.git A couple of small device-related utilities. PiperOrigin-RevId: 190312148 --- diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index acfdcd1..e6ad564 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2884,9 +2884,11 @@ py_library( ":client", ":control_flow_ops", ":data_flow_ops", + ":device", ":errors", ":framework", ":framework_for_generated_wrappers", + ":framework_ops", ":gradients", ":init_ops", ":io_ops", @@ -2911,6 +2913,7 @@ py_library( ":variable_scope", ":variables", "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", "//third_party/py/numpy", "@six_archive//:six", ], diff --git a/tensorflow/python/training/device_util.py b/tensorflow/python/training/device_util.py new file mode 100644 index 0000000..f1137e8 --- /dev/null +++ b/tensorflow/python/training/device_util.py @@ -0,0 +1,68 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Device-related support functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.framework import device as tf_device +from tensorflow.python.framework import ops + + +def canonicalize(d): + d = tf_device.DeviceSpec.from_string(d) + assert d.device_type is None or d.device_type == d.device_type.upper(), ( + "Device type '%s' must be all-caps." % (d.device_type,)) + # Fill in missing device fields using defaults. + result = tf_device.DeviceSpec( + job="localhost", replica=0, task=0, device_type="CPU", device_index=0) + result.merge_from(d) + return result.to_string() + + +class _FakeNodeDef(object): + """A fake NodeDef for _FakeOperation.""" + + def __init__(self): + self.op = "" + self.name = "" + + +class _FakeOperation(object): + """A fake Operation object to pass to device functions.""" + + def __init__(self): + self.device = "" + self.type = "" + self.name = "" + self.node_def = _FakeNodeDef() + + def _set_device(self, device): + self.device = ops._device_string(device) # pylint: disable=protected-access + + +def current(): + """Return a string (not canonicalized) for the current device.""" + # TODO(josh11b): Work out how this function interacts with ops.colocate_with. + ctx = context.context() + if ctx.executing_eagerly(): + d = ctx.device_name + else: + op = _FakeOperation() + ops.get_default_graph()._apply_device_functions(op) # pylint: disable=protected-access + d = op.device + return d