Open source rewrite_for_inference().
author A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 23 May 2018 23:34:00 +0000 (16:34 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 23 May 2018 23:38:07 +0000 (16:38 -0700)
PiperOrigin-RevId: 197810460

tensorflow/contrib/tpu/python/tpu/tpu.py

index e2f57ce..f531ae5 100644 (file)
@@ -21,6 +21,7 @@ from __future__ import print_function
 
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
+from tensorflow.contrib.framework.python.framework import experimental
 from tensorflow.contrib.tpu.python.ops import tpu_ops
 from tensorflow.contrib.tpu.python.tpu import tpu_function
 
@@ -867,3 +868,152 @@ def rewrite(computation,
       device_assignment=device_assignment,
       name=name)[0]
   # pylint: enable=indexing-exception
+
+  # Operations that indicate some error in the user's inference graph.
+_BLACKLISTED_INFERENCE_OPS = set([
+    "ReadVariableOp",
+    "AssignVariableOp",
+    "AssignAddVariableOp",
+    "AssignSubVariableOp",
+    "VarHandleOp",
+    "Variable",
+    "VariableV2",
+])
+
+
class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside a TPU inference computation.

  The primary role of `_TPUInferenceContext` is to sanity check operators
  inside a tpu.rewrite_for_inference() computation: any op that reads or
  mutates variable state directly (see `_BLACKLISTED_INFERENCE_OPS`) is
  rejected at graph-construction time rather than failing later on the TPU.
  """

  def __init__(self, name):
    super(_TPUInferenceContext, self).__init__()
    self._name = name

  def AddOp(self, op):
    self._AddOpInternal(op)

  def _AddOpInternal(self, op):
    """Validates `op` and forwards it to the enclosing context, if any."""
    # pylint: disable=protected-access
    if op.type in _BLACKLISTED_INFERENCE_OPS:
      raise NotImplementedError(
          "Operation of type %s (%s) is not supported on the TPU for inference."
          " Execution will fail if this op is used in the graph. Make sure your"
          " variables are using variable_scope." % (op.type, op.name))
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def AddValue(self, val):
    # Values pass through unchanged; only defer to the outer context so
    # nested control-flow bookkeeping stays consistent.
    result = val
    if self._outer_context:
      result = self._outer_context.AddValue(val)
    return result

  def AddInnerOp(self, op):
    self._AddOpInternal(op)

  @property
  def grad_state(self):
    # Inference computations carry no gradient-loop state.
    return None
+
+
@experimental
def validate_inference_rewrite_for_variables(graph):
  """Validates whether rewrite_for_inference() 'worked' for variables.

     The rewrite_for_inference() method is supposed to append
     GuaranteeConstOps after ReadVariableOps, but this mechanism works only
     if you are using tf.get_variable() to create and access variables in your
     tpu computation. This validation method can be called immediately after
     calling tpu.rewrite_for_inference() to check whether GuaranteeConstOps
     were added to the graph.

     Typical usages:
       tpu.validate_inference_rewrite_for_variables(tf.get_default_graph())

       tpu.validate_inference_rewrite_for_variables(sess.graph)

  Args:
    graph: The graph which needs to be validated.
  Raises:
    RuntimeError: if validation failed.
  """
  # A generator (not a list comprehension) lets any() short-circuit at the
  # first GuaranteeConst instead of scanning every op in the graph.
  if not any(op.type == "GuaranteeConst" for op in graph.get_operations()):
    raise RuntimeError(
        "No GuaranteeConst ops found in the graph after "
        "running tpu.rewrite_for_inference(...). Please "
        "check that you are using tf.get_variable() to "
        "create and access variables in your tpu "
        "computation.")
+
+
@experimental
def rewrite_for_inference(computation,
                          inputs=None,
                          infeed_queue=None,
                          device_assignment=None,
                          name=None):
  """Rewrites `computation` for inference on a TPU system.

     Other than 'rewriting' the computation to run on a TPU, if using variables
     in your computation, it moves the ReadVariableOps outside the TPU
     computation, and adds GuaranteeConst ops just after the ReadVariableOps.
     This mechanism works only if you are using tf.get_variable() to create and
     access variables in your tpu computation. You can validate whether
     this worked, by calling validate_inference_rewrite_for_variables() method
     immediately after this method to check whether GuaranteeConstOps were
     added to the graph.

  Args:
    computation: A Python function that builds a computation to apply
      to the input. If the function takes n inputs, 'inputs' should be
      a list of n tensors. If the function returns m outputs, rewrite
      will return a list of m tensors.
    inputs: A list of input tensors or `None` (equivalent to an empty list).
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to `computation`.
    device_assignment: if not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. May be omitted for a single-core computation, in which
      case the core attached to task 0, TPU device 0 is used.
    name: The name of the operator.
  Returns:
    A list of output tensors.
  """

  def guarantee_const_getter(getter, name, *args, **kwargs):
    # Wrap every variable read in a GuaranteeConst. Control dependencies are
    # cleared so the GuaranteeConst is not pulled into whatever control
    # structure is active at the time the variable is read.
    with ops.control_dependencies(None):
      return array_ops.guarantee_const(
          getter(name, *args, **kwargs), name=name + "/GuaranteeConst")

  def wrapped_computation(*args, **kwargs):
    """Execute computation under `_TPUInferenceContext`."""
    context = _TPUInferenceContext(
        name=ops.get_default_graph().unique_name("rewrite_for_inference"))
    # Snapshot the variable scope before mutating it so we can restore it no
    # matter how `computation` exits.
    vscope = variable_scope.get_variable_scope()
    prev_custom_getter = vscope.custom_getter
    prev_caching_device = vscope.caching_device
    context.Enter()
    try:
      vscope.set_custom_getter(guarantee_const_getter)
      vscope.set_caching_device(lambda op: op.device)
      result = computation(*args, **kwargs)
    finally:
      # Restore in `finally`: the original code only restored on the success
      # path, so a raising `computation` leaked guarantee_const_getter (and
      # the caching device) into all subsequent graph construction.
      vscope.set_custom_getter(prev_custom_getter)
      vscope.set_caching_device(prev_caching_device)
      context.Exit()
    return result

  # pylint: disable=undefined-variable
  return rewrite(
      wrapped_computation,
      inputs=inputs,
      infeed_queue=infeed_queue,
      device_assignment=device_assignment,
      name=name)
  # pylint: enable=undefined-variable