"//tensorflow/core:protos_all_cc",
],
)
+
+# Generates C++ op-wrapper sources (ops/xla_jit_op.cc / ops/xla_jit_op.h)
+# for the XLA JIT ops defined in //tensorflow/compiler/jit/ops:xla_ops.
+tf_gen_op_wrapper_cc(
+ name = "xla_jit_op_gen",
+ out_ops_file = "ops/xla_jit_op",
+ deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
+)
+
+# C++ client library wrapping the generated XLA JIT op bindings
+# (srcs/hdrs are produced by the xla_jit_op_gen rule above).
+cc_library(
+ name = "xla_jit_ops",
+ srcs = ["ops/xla_jit_op.cc"],
+ hdrs = ["ops/xla_jit_op.h"],
+ deps = [
+ "//tensorflow/cc:const_op",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:scope",
+ "//tensorflow/compiler/jit/ops:xla_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
REGISTER_OP("TPUReplicateMetadata")
.Attr("num_replicas: int >= 0")
.Attr("topology: string = \"\"")
+ .Attr("use_tpu: bool = true")
.Attr("device_assignment: list(int) = []")
.Attr("computation_shape: list(int) = []")
.Attr("host_compute_core: list(string) = []")
.Attr("computation: func")
.Attr("num_replicas: int >= 1")
.Attr("topology: string = \"\"")
+ .Attr("use_tpu: bool = true")
.Attr("device_assignment: list(int) = []")
.Attr("host_compute_core: list(string) = []")
.Attr("computation_shape: list(int) = []")
num_replicas: the number of replicas of the computation to run.
topology: A serialized tensorflow.tpu.TopologyProto that describes the TPU
topology.
+use_tpu: a bool indicating whether this computation will run on TPU or CPU/GPU.
+Currently, only a default placement is supported (the computation is placed on
+GPU if one is available, and on CPU if not).
computation_shape: a [mesh_dimension] array describing the shape of each
computation replica in numbers of cores in the TPU mesh.
device_assignment: a flattened array with shape
inputs=None,
infeed_queue=None,
device_assignment=None,
- name=None):
+ name=None,
+ use_tpu=True):
"""Builds graph operators that runs compilation and replicated computation.
This is a lower level interface than replicate that returns a separate compile
only one core, and there is either only one replica, or the number of
replicas is equal to the number of cores in the TPU system.
name: (Deprecated) Does nothing.
+ use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
+ backends. Currently, only a default placement is supported (the
+ computation is placed on GPU if one is available, and on CPU if not).
Returns:
A list of lists with the first list corresponding to the compile op and the
second a list of output tensors, indexed by `[replica_num][output_num]`.
context.Enter()
metadata = tpu_ops.tpu_replicate_metadata(
- num_replicas=num_replicas, **metadata_kwargs)
+ num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)
with tpu_function.tpu_shard_context(
num_replicas), ops.control_dependencies([metadata]):
for i in xrange(output_arity)]
with ops.control_dependencies([metadata]):
- compile_status = tpu_ops.tpu_compilation_result()
- op = compile_status.op
- attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
- op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access
+ if use_tpu:
+ compile_status = tpu_ops.tpu_compilation_result()
+ op = compile_status.op
+ attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
+ op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access
+ else:
+ compile_status = control_flow_ops.no_op(name="compilation_status")
with ops.control_dependencies(output_operations):
if output_arity == 0: