From 028993236b2ee9674ab11294e9985d7beaf376bb Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 21 May 2018 23:37:12 -0700
Subject: [PATCH] Enable tpu.rewrite to work on XLA CPU/GPU backends.

PiperOrigin-RevId: 197517946
---
 tensorflow/compiler/tf2xla/cc/BUILD           | 22 ++++++++++++++++++++++
 tensorflow/contrib/tpu/ops/replication_ops.cc |  5 +++++
 tensorflow/contrib/tpu/python/tpu/tpu.py      | 19 +++++++++++++------
 3 files changed, 40 insertions(+), 6 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD
index 4f8bb8a..ea8d1b3 100644
--- a/tensorflow/compiler/tf2xla/cc/BUILD
+++ b/tensorflow/compiler/tf2xla/cc/BUILD
@@ -27,3 +27,25 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
     ],
 )
+
+tf_gen_op_wrapper_cc(
+    name = "xla_jit_op_gen",
+    out_ops_file = "ops/xla_jit_op",
+    deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
+)
+
+cc_library(
+    name = "xla_jit_ops",
+    srcs = ["ops/xla_jit_op.cc"],
+    hdrs = ["ops/xla_jit_op.h"],
+    deps = [
+        "//tensorflow/cc:const_op",
+        "//tensorflow/cc:ops",
+        "//tensorflow/cc:scope",
+        "//tensorflow/compiler/jit/ops:xla_ops",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+    ],
+)
diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc
index defed00..ab2a7a0 100644
--- a/tensorflow/contrib/tpu/ops/replication_ops.cc
+++ b/tensorflow/contrib/tpu/ops/replication_ops.cc
@@ -25,6 +25,7 @@ using shape_inference::ShapeHandle;
 REGISTER_OP("TPUReplicateMetadata")
     .Attr("num_replicas: int >= 0")
     .Attr("topology: string = \"\"")
+    .Attr("use_tpu: bool = true")
     .Attr("device_assignment: list(int) = []")
     .Attr("computation_shape: list(int) = []")
     .Attr("host_compute_core: list(string) = []")
@@ -72,6 +73,7 @@ REGISTER_OP("TPUReplicate")
     .Attr("computation: func")
     .Attr("num_replicas: int >= 1")
     .Attr("topology: string = \"\"")
+    .Attr("use_tpu: bool = true")
     .Attr("device_assignment: list(int) = []")
     .Attr("host_compute_core: list(string) = []")
     .Attr("computation_shape: list(int) = []")
@@ -93,6 +95,9 @@ computation: a function containing the computation to run.
 num_replicas: the number of replicas of the computation to run.
 topology: A serialized tensorflow.tpu.TopologyProto that describes the TPU
 topology.
+use_tpu: a bool indicating if this computation will run on TPU or CPU/GPU.
+Currently, only supports a default placement (computation is placed on GPU
+if one is available, and on CPU if not).
 computation_shape: a [mesh_dimension] array describing the shape of each
 computation replica in numbers of cores in the TPU mesh.
 device_assignment: a flattened array with shape
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index c8f24ed..e2f57ce 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -394,7 +394,8 @@ def split_compile_and_replicate(computation,
                                 inputs=None,
                                 infeed_queue=None,
                                 device_assignment=None,
-                                name=None):
+                                name=None,
+                                use_tpu=True):
   """Builds graph operators that runs compilation and replicated computation.
 
   This is a lower level interface than replicate that returns a separate compile
@@ -417,6 +418,9 @@ def split_compile_and_replicate(computation,
       only one core, and there is either only one replica, or the number of
       replicas is equal to the number of cores in the TPU system.
     name: (Deprecated) Does nothing.
+    use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
+      backends. Currently, only supports a default placement (computation is
+      placed on GPU if one is available, and on CPU if not).
   Returns:
     A list of lists with the first list corresponding to the compile op and the
     second a list of output tensors, indexed by `[replica_num][output_num]`.
@@ -502,7 +506,7 @@ def split_compile_and_replicate(computation,
     context.Enter()
 
     metadata = tpu_ops.tpu_replicate_metadata(
-        num_replicas=num_replicas, **metadata_kwargs)
+        num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)
 
     with tpu_function.tpu_shard_context(
         num_replicas), ops.control_dependencies([metadata]):
@@ -590,10 +594,13 @@ def split_compile_and_replicate(computation,
                for i in xrange(output_arity)]
 
     with ops.control_dependencies([metadata]):
-      compile_status = tpu_ops.tpu_compilation_result()
-      op = compile_status.op
-      attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
-      op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
+      if use_tpu:
+        compile_status = tpu_ops.tpu_compilation_result()
+        op = compile_status.op
+        attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
+        op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
+      else:
+        compile_status = control_flow_ops.no_op(name="compilation_status")
 
     with ops.control_dependencies(output_operations):
       if output_arity == 0:
-- 
2.7.4
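
Usage note (not part of the patch): a minimal sketch of how the new use_tpu flag
might be exercised, assuming the contrib-era import path shown in the diff and a
toy computation; the function body, placeholder shape, and feed values below are
illustrative assumptions, not taken from the change. Passing use_tpu=False makes
tpu_replicate_metadata carry use_tpu=false, so the wrapped computation is handed
to the XLA CPU/GPU backends with default placement (GPU if available, else CPU).

    # Illustrative sketch only -- assumes a TF build from this era with XLA and
    # tensorflow.contrib.tpu available.
    import numpy as np
    import tensorflow as tf

    from tensorflow.contrib.tpu.python.tpu import tpu


    def computation(x):
      # Any XLA-compilable function of the per-replica inputs.
      return x * 2.0 + 1.0


    x = tf.placeholder(tf.float32, shape=[2, 8])

    # One replica, one input. With use_tpu=False the computation targets the
    # XLA CPU/GPU backend instead of TPU; only default placement is supported.
    compile_op, outputs = tpu.split_compile_and_replicate(
        computation, inputs=[[x]], use_tpu=False)

    with tf.Session() as sess:
      # outputs is indexed by [replica_num][output_num].
      print(sess.run(outputs[0][0], {x: np.zeros([2, 8], np.float32)}))

With use_tpu=False the compile op returned in the first position is a plain
no_op named "compilation_status" (per the last hunk), so callers that poll the
TPU compilation result should only do so when targeting TPU.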