Enable tpu.rewrite to work on XLA CPU/GPU backends.
author    A. Unique TensorFlower <gardener@tensorflow.org>
          Tue, 22 May 2018 06:37:12 +0000 (23:37 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Tue, 22 May 2018 06:40:05 +0000 (23:40 -0700)
PiperOrigin-RevId: 197517946

tensorflow/compiler/tf2xla/cc/BUILD
tensorflow/contrib/tpu/ops/replication_ops.cc
tensorflow/contrib/tpu/python/tpu/tpu.py

tensorflow/compiler/tf2xla/cc/BUILD
index 4f8bb8a..ea8d1b3 100644
@@ -27,3 +27,25 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
     ],
 )
+
+tf_gen_op_wrapper_cc(
+    name = "xla_jit_op_gen",
+    out_ops_file = "ops/xla_jit_op",
+    deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
+)
+
+cc_library(
+    name = "xla_jit_ops",
+    srcs = ["ops/xla_jit_op.cc"],
+    hdrs = ["ops/xla_jit_op.h"],
+    deps = [
+        "//tensorflow/cc:const_op",
+        "//tensorflow/cc:ops",
+        "//tensorflow/cc:scope",
+        "//tensorflow/compiler/jit/ops:xla_ops",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+    ],
+)
tensorflow/contrib/tpu/ops/replication_ops.cc
index defed00..ab2a7a0 100644
@@ -25,6 +25,7 @@ using shape_inference::ShapeHandle;
 REGISTER_OP("TPUReplicateMetadata")
     .Attr("num_replicas: int >= 0")
     .Attr("topology: string = \"\"")
+    .Attr("use_tpu: bool = true")
     .Attr("device_assignment: list(int) = []")
     .Attr("computation_shape: list(int) = []")
     .Attr("host_compute_core: list(string) = []")
@@ -72,6 +73,7 @@ REGISTER_OP("TPUReplicate")
     .Attr("computation: func")
     .Attr("num_replicas: int >= 1")
     .Attr("topology: string = \"\"")
+    .Attr("use_tpu: bool = true")
     .Attr("device_assignment: list(int) = []")
     .Attr("host_compute_core: list(string) = []")
     .Attr("computation_shape: list(int) = []")
@@ -93,6 +95,9 @@ computation: a function containing the computation to run.
 num_replicas: the number of replicas of the computation to run.
 topology: A serialized tensorflow.tpu.TopologyProto that describes the TPU
 topology.
+use_tpu: a bool indicating if this computation will run on TPU or CPU/GPU.
+Currently, only supports a default placement (computation is placed on GPU
+if one is available, and on CPU if not).
 computation_shape: a [mesh_dimension] array describing the shape of each
   computation replica in numbers of cores in the TPU mesh.
 device_assignment: a flattened array with shape
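For reference, a minimal sketch of how the new attribute surfaces through the generated Python op wrapper. The import path mirrors the contrib/tpu layout touched by this change, and creating the metadata op outside of `split_compile_and_replicate` is purely illustrative; the attribute defaults to true, so existing TPU graphs are unaffected.

```python
# Hedged sketch: exercise the new `use_tpu` attr on TPUReplicateMetadata.
import tensorflow as tf
from tensorflow.contrib.tpu.python.ops import tpu_ops

with tf.Graph().as_default():
  # Defaults to use_tpu=True (existing behavior); use_tpu=False marks the
  # replicated cluster for the XLA CPU/GPU backends instead of TPU.
  metadata = tpu_ops.tpu_replicate_metadata(num_replicas=2, use_tpu=False)
```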
tensorflow/contrib/tpu/python/tpu/tpu.py
index c8f24ed..e2f57ce 100644
@@ -394,7 +394,8 @@ def split_compile_and_replicate(computation,
                                 inputs=None,
                                 infeed_queue=None,
                                 device_assignment=None,
-                                name=None):
+                                name=None,
+                                use_tpu=True):
   """Builds graph operators that runs compilation and replicated computation.
 
   This is a lower level interface than replicate that returns a separate compile
@@ -417,6 +418,9 @@ def split_compile_and_replicate(computation,
       only one core, and there is either only one replica, or the number of
       replicas is equal to the number of cores in the TPU system.
     name: (Deprecated) Does nothing.
+    use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
+      backends. Currently, only supports a default placement (computation is
+      placed on GPU if one is available, and on CPU if not).
   Returns:
     A list of lists with the first list corresponding to the compile op and the
     second a list of output tensors, indexed by `[replica_num][output_num]`.
@@ -502,7 +506,7 @@ def split_compile_and_replicate(computation,
     context.Enter()
 
     metadata = tpu_ops.tpu_replicate_metadata(
-        num_replicas=num_replicas, **metadata_kwargs)
+        num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)
 
     with tpu_function.tpu_shard_context(
         num_replicas), ops.control_dependencies([metadata]):
@@ -590,10 +594,13 @@ def split_compile_and_replicate(computation,
              for i in xrange(output_arity)]
 
   with ops.control_dependencies([metadata]):
-    compile_status = tpu_ops.tpu_compilation_result()
-    op = compile_status.op
-    attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
-    op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
+    if use_tpu:
+      compile_status = tpu_ops.tpu_compilation_result()
+      op = compile_status.op
+      attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
+      op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
+    else:
+      compile_status = control_flow_ops.no_op(name="compilation_status")
 
   with ops.control_dependencies(output_operations):
     if output_arity == 0:
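
Taken together, the Python change lets callers request the XLA CPU/GPU path directly. Below is a minimal graph-mode usage sketch; the module import path is the file touched above, while the toy computation, the single-replica inputs, and the Session-based run loop are illustrative assumptions rather than part of this change.

```python
# Hedged sketch: run a replicated computation on the XLA CPU/GPU backends.
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu

def computation(x):
  # Placeholder computation for illustration only.
  return tf.reduce_sum(x * x)

inputs = [[tf.constant([1.0, 2.0, 3.0])]]  # one replica, one input tensor

# With use_tpu=False the compile op degrades to the no-op named
# "compilation_status" added above, and the computation is compiled for the
# XLA CPU/GPU backends (GPU if one is available, otherwise CPU).
compile_op, outputs = tpu.split_compile_and_replicate(
    computation, inputs=inputs, use_tpu=False)

with tf.Session() as sess:
  sess.run(compile_op)
  print(sess.run(outputs))  # expected: [[14.0]]
```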