"tf_cuda_cc_test",
"tf_cc_test",
"tf_copts",
+ "tfe_xla_copts",
"tf_cuda_library",
)
"c_api_internal.h",
],
hdrs = ["c_api.h"],
- copts = tf_copts(),
+ copts = tf_copts() + tfe_xla_copts(),
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
- }) + ["//tensorflow/core:gpu_runtime"],
+ }) + select({
+ "//tensorflow:with_xla_support": [
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/jit",
+ ],
+ "//conditions:default": [],
+ }) + [
+ "//tensorflow/core:gpu_runtime",
+ ],
)
tf_cuda_library(
tf_cuda_cc_test(
name = "c_api_test",
srcs = ["c_api_test.cc"],
+ extra_copts = tfe_xla_copts(),
tags = [
"guitar",
"multi_gpu",
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/runtime.h"
+#ifdef TENSORFLOW_EAGER_USE_XLA
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#endif // TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}
+// Returns true iff the device's type name contains "XLA" (e.g. XLA_CPU,
+// XLA_GPU). A null device is treated as non-XLA.
+bool IsXLA(tensorflow::Device* d) {
+  if (d == nullptr) return false;
+  // Match on the substring so all XLA device flavors are covered.
+  const auto& device_type = d->attributes().device_type();
+  return device_type.find("XLA") != std::string::npos;
+}
+
// Returns a human-readable name for `d`; a null device (host CPU in this
// API) is reported as "cpu:0".
string DeviceName(tensorflow::Device* d) {
  if (d == nullptr) {
    return "cpu:0";
  }
  return d->name();
}
(srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
const bool dst_cpu = IsCPU(dstd);
const bool src_cpu = IsCPU(srcd);
- if (is_same_device) {
+ // both_on_cpu can be true and yet is_same_device is false, if one of src/dst
+ // has device type XLA_CPU, and the other CPU.
+ const bool both_on_cpu = src_cpu && dst_cpu;
+ if (is_same_device || both_on_cpu) {
return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd);
}
tensorflow::Tensor* src = &(h->t);
"//conditions:default": ["-pthread"]
}))
+
+def tfe_xla_copts():
+  # Compiler options that enable XLA support in the eager C API: defines
+  # TENSORFLOW_EAGER_USE_XLA only when building with
+  # --define with_xla_support=true; otherwise contributes no flags.
+  return select({
+    "//tensorflow:with_xla_support": ["-DTENSORFLOW_EAGER_USE_XLA"],
+    "//conditions:default": [],
+  })
+
def tf_opts_nortti_if_android():
return if_android([
"-fno-rtti",
tags=[],
data=[],
size="medium",
+ extra_copts=[],
linkstatic=0,
args=[],
linkopts=[]):
tags=tags + ["manual"],
data=data,
size=size,
+ extra_copts=extra_copts,
linkstatic=linkstatic,
linkopts=linkopts,
args=args)
tags=tags + tf_cuda_tests_tags(),
data=data,
size=size,
+ extra_copts=extra_copts,
linkopts=linkopts,
args=args)