from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
device_assignment=device_assignment,
name=name)[0]
# pylint: enable=indexing-exception
+
+ # Operations that indicate some error in the user's inference graph.
+_BLACKLISTED_INFERENCE_OPS = set([
+ "ReadVariableOp",
+ "AssignVariableOp",
+ "AssignAddVariableOp",
+ "AssignSubVariableOp",
+ "VarHandleOp",
+ "Variable",
+ "VariableV2",
+])
+
+
class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside a TPU inference computation.

  The primary role of `_TPUInferenceContext` is to sanity check operators
  inside a tpu.rewrite_for_inference() computation: any op whose type is in
  `_BLACKLISTED_INFERENCE_OPS` (an unrewritten variable read/write) raises
  immediately at graph-construction time instead of failing later on the TPU.
  """

  def __init__(self, name):
    super(_TPUInferenceContext, self).__init__()
    self._name = name

  def AddOp(self, op):
    """Called by the graph for each op created while this context is active."""
    self._AddOpInternal(op)

  def _AddOpInternal(self, op):
    """Rejects blacklisted ops, then forwards `op` to the outer context."""
    # pylint: disable=protected-access
    if op.type in _BLACKLISTED_INFERENCE_OPS:
      raise NotImplementedError(
          "Operation of type %s (%s) is not supported on the TPU for inference."
          " Execution will fail if this op is used in the graph. Make sure your"
          " variables are using variable_scope." % (op.type, op.name))
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def AddValue(self, val):
    """Forwards `val` to the outer context if one exists; else a no-op."""
    result = val
    if self._outer_context:
      result = self._outer_context.AddValue(val)
    return result

  def AddInnerOp(self, op):
    self._AddOpInternal(op)

  @property
  def grad_state(self):
    # An inference context carries no gradient-loop state.
    return None
+
+
@experimental
def validate_inference_rewrite_for_variables(graph):
  """Validates whether rewrite_for_inference() 'worked' for variables.

  The rewrite_for_inference() method is supposed to append
  GuaranteeConstOps after ReadVariableOps, but this mechanism works only
  if you are using tf.get_variable() to create and access variables in your
  tpu computation. This validation method can be called immediately after
  calling tpu.rewrite_for_inference() to check whether GuaranteeConstOps
  were added to the graph.

  Typical usages:
    tpu.validate_inference_rewrite_for_variables(tf.get_default_graph())

    tpu.validate_inference_rewrite_for_variables(sess.graph)

  Args:
    graph: The graph which needs to be validated.
  Raises:
    RuntimeError: if validation failed.
  """
  # Generator expression: short-circuits on the first match instead of
  # materializing a bool for every op in the graph.
  if not any(x.type == "GuaranteeConst" for x in graph.get_operations()):
    raise RuntimeError(
        "No GuaranteeConst ops found in the graph after "
        "running tpu.rewrite_for_inference(...). Please "
        "check that you are using tf.get_variable() to "
        "create and access variables in your tpu "
        "computation.")
+
+
@experimental
def rewrite_for_inference(computation,
                          inputs=None,
                          infeed_queue=None,
                          device_assignment=None,
                          name=None):
  """Rewrites `computation` for inference on a TPU system.

  Other than 'rewriting' the computation to run on a TPU, if using variables
  in your computation, it moves the ReadVariableOps outside the TPU
  computation, and adds GuaranteeConst ops just after the ReadVariableOps.
  This mechanism works only if you are using tf.get_variable() to create and
  access variables in your tpu computation. You can validate whether
  this worked, by calling validate_inference_rewrite_for_variables() method
  immediately after this method to check whether GuaranteeConstOps were
  added to the graph.

  Args:
    computation: A Python function that builds a computation to apply
      to the input. If the function takes n inputs, 'inputs' should be
      a list of n tensors. If the function returns m outputs, rewrite
      will return a list of m tensors.
    inputs: A list of input tensors or `None` (equivalent to an empty list).
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to `computation`.
    device_assignment: if not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. May be omitted for a single-core computation, in which
      case the core attached to task 0, TPU device 0 is used.
    name: The name of the operator.
  Returns:
    A list of output tensors.
  """

  def guarantee_const_getter(getter, name, *args, **kwargs):
    # Wrap every variable read in a GuaranteeConst so the TPU compiler may
    # treat the value as constant for the duration of the computation.
    with ops.control_dependencies(None):
      return array_ops.guarantee_const(
          getter(name, *args, **kwargs), name=name + "/GuaranteeConst")

  def wrapped_computation(*args, **kwargs):
    """Execute computation under `_TPUInferenceContext`."""
    context = _TPUInferenceContext(
        name=ops.get_default_graph().unique_name("rewrite_for_inference"))
    try:
      context.Enter()

      vscope = variable_scope.get_variable_scope()
      prev_custom_getter = vscope.custom_getter
      prev_caching_device = vscope.caching_device
      vscope.set_custom_getter(guarantee_const_getter)
      vscope.set_caching_device(lambda op: op.device)
      try:
        result = computation(*args, **kwargs)
      finally:
        # Restore the variable scope even when `computation` raises;
        # otherwise the GuaranteeConst getter and caching device would leak
        # into all subsequent graph construction in this scope.
        vscope.set_custom_getter(prev_custom_getter)
        vscope.set_caching_device(prev_caching_device)
    finally:
      context.Exit()
    return result

  # pylint: disable=undefined-variable
  return rewrite(
      wrapped_computation,
      inputs=inputs,
      infeed_queue=infeed_queue,
      device_assignment=device_assignment,
      name=name)
  # pylint: enable=undefined-variable