From 4ba9e8eed9dfe0727db000bdd8be5384f39e6bd9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 23 May 2018 16:34:00 -0700 Subject: [PATCH] Open source rewrite_for_inference(). PiperOrigin-RevId: 197810460 --- tensorflow/contrib/tpu/python/tpu/tpu.py | 150 +++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index e2f57ce..f531ae5 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -21,6 +21,7 @@ from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function @@ -867,3 +868,152 @@ def rewrite(computation, device_assignment=device_assignment, name=name)[0] # pylint: enable=indexing-exception + + # Operations that indicate some error in the user's inference graph. +_BLACKLISTED_INFERENCE_OPS = set([ + "ReadVariableOp", + "AssignVariableOp", + "AssignAddVariableOp", + "AssignSubVariableOp", + "VarHandleOp", + "Variable", + "VariableV2", +]) + + +class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): + """A `ControlFlowContext` for nodes inside a TPU inference computation. + + The primary role of `_TPUInferenceContext` is to sanity check operators inside + a tpu.rewrite_for_inference() computation. + """ + + def __init__(self, name): + super(_TPUInferenceContext, self).__init__() + self._name = name + + def AddOp(self, op): + self._AddOpInternal(op) + + def _AddOpInternal(self, op): + # pylint: disable=protected-access + if op.type in _BLACKLISTED_INFERENCE_OPS: + raise NotImplementedError( + "Operation of type %s (%s) is not supported on the TPU for inference." + " Execution will fail if this op is used in the graph. 
Make sure your" + " variables are using variable_scope." % (op.type, op.name)) + if self._outer_context: + self._outer_context.AddInnerOp(op) + + def AddValue(self, val): + result = val + if self._outer_context: + result = self._outer_context.AddValue(val) + return result + + def AddInnerOp(self, op): + self._AddOpInternal(op) + + @property + def grad_state(self): + return None + + +@experimental +def validate_inference_rewrite_for_variables(graph): + """Validates whether rewrite_for_inference() 'worked' for variables. + + The rewrite_for_inference() method is supposed to append + GuaranteeConstOps after ReadVariableOps, but this mechanism works only + if you are using tf.get_variable() to create and access variables in your + tpu computation. This validation method can be called immediately after + calling tpu.rewrite_for_inference() to check whether GuaranteeConstOps + were added to the graph. + + Typical usages: + tpu.validate_inference_rewrite_for_variables(tf.get_default_graph()) + + tpu.validate_inference_rewrite_for_variables(sess.graph) + + Args: + graph: The graph which needs to be validated. + Raises: + RuntimeError: if validation failed. + """ + if not any([x.type == "GuaranteeConst" for x in graph.get_operations()]): + raise RuntimeError( + "No GuaranteeConst ops found in the graph after " + "running tpu.rewrite_for_inference(...). Please " + "check that you are using tf.get_variable() to " + "create and access variables in your tpu " + "computation.") + + +@experimental +def rewrite_for_inference(computation, + inputs=None, + infeed_queue=None, + device_assignment=None, + name=None): + """Rewrites `computation` for inference on a TPU system. + + Other than 'rewriting' the computation to run on a TPU, if using variables + in your computation, it moves the ReadVariableOps outside the TPU + computation, and adds GuaranteeConst ops just after the ReadVariableOps. 
+ This mechanism works only if you are using tf.get_variable() to create and + access variables in your tpu computation. You can validate whether + this worked, by calling validate_inference_rewrite_for_variables() method + immediately after this method to check whether GuaranteeConstOps were + added to the graph. + + Args: + computation: A Python function that builds a computation to apply + to the input. If the function takes n inputs, 'inputs' should be + a list of n tensors. If the function returns m outputs, rewrite + will return a list of m tensors. + inputs: A list of input tensors or `None` (equivalent to an empty list). + infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple + of arguments as inputs to `computation`. + device_assignment: if not `None`, a `DeviceAssignment` describing the + mapping between logical cores in the computation with physical cores in + the TPU topology. May be omitted for a single-core computation, in which + case the core attached to task 0, TPU device 0 is used. + name: The name of the operator. + Returns: + A list of output tensors. 
+ """ + + def guarantee_const_getter(getter, name, *args, **kwargs): + with ops.control_dependencies(None): + return array_ops.guarantee_const( + getter(name, *args, **kwargs), name=name + "/GuaranteeConst") + + def wrapped_computation(*args, **kwargs): + """Execute computation under `_TPUInferenceContext`.""" + context = _TPUInferenceContext( + name=ops.get_default_graph().unique_name("rewrite_for_inference")) + try: + context.Enter() + + vscope = variable_scope.get_variable_scope() + prev_custom_getter = vscope.custom_getter + prev_caching_device = vscope.caching_device + vscope.set_custom_getter(guarantee_const_getter) + vscope.set_caching_device(lambda op: op.device) + + result = computation(*args, **kwargs) + + vscope.set_custom_getter(prev_custom_getter) + vscope.set_caching_device(prev_caching_device) + finally: + context.Exit() + return result + + # pylint: disable=undefined-variable + return rewrite( + wrapped_computation, + inputs=inputs, + infeed_queue=infeed_queue, + device_assignment=device_assignment, + name=name) + # pylint: enable=undefined-variable -- 2.7.4