Internal change
authorMingsheng Hong <hongm@google.com>
Tue, 6 Feb 2018 19:51:18 +0000 (11:51 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Feb 2018 20:14:38 +0000 (12:14 -0800)
PiperOrigin-RevId: 184715822

tensorflow/c/eager/BUILD
tensorflow/c/eager/c_api.cc
tensorflow/tensorflow.bzl

index e62310d811462f88af93505393b622d9a87c72d3..3505f70dc156605f3f35a5b12834f83994eee043 100644 (file)
@@ -6,6 +6,7 @@ load(
     "tf_cuda_cc_test",
     "tf_cc_test",
     "tf_copts",
+    "tfe_xla_copts",
     "tf_cuda_library",
 )
 
@@ -16,7 +17,7 @@ tf_cuda_library(
         "c_api_internal.h",
     ],
     hdrs = ["c_api.h"],
-    copts = tf_copts(),
+    copts = tf_copts() + tfe_xla_copts(),
     visibility = ["//visibility:public"],
     deps = select({
         "//tensorflow:android": [
@@ -33,7 +34,15 @@ tf_cuda_library(
             "//tensorflow/core:lib_internal",
             "//tensorflow/core:protos_all_cc",
         ],
-    }) + ["//tensorflow/core:gpu_runtime"],
+    }) + select({
+        "//tensorflow:with_xla_support": [
+            "//tensorflow/compiler/tf2xla:xla_compiler",
+            "//tensorflow/compiler/jit",
+        ],
+        "//conditions:default": [],
+    }) + [
+        "//tensorflow/core:gpu_runtime",
+    ],
 )
 
 tf_cuda_library(
@@ -56,6 +65,7 @@ tf_cuda_library(
 tf_cuda_cc_test(
     name = "c_api_test",
     srcs = ["c_api_test.cc"],
+    extra_copts = tfe_xla_copts(),
     tags = [
         "guitar",
         "multi_gpu",
index d65b592895950cea3b528478e5bd6257ac688cc6..3a6d2ce45bfe8cc86cdbfaa6702e2ce2e46be65d 100644 (file)
@@ -25,6 +25,9 @@ limitations under the License.
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/c/eager/runtime.h"
+#ifdef TENSORFLOW_EAGER_USE_XLA
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#endif  // TENSORFLOW_EAGER_USE_XLA
 #include "tensorflow/core/common_runtime/copy_tensor.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
@@ -48,6 +51,12 @@ bool IsCPU(tensorflow::Device* d) {
   return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
 }
 
+bool IsXLA(tensorflow::Device* d) {
+  if (d == nullptr) return false;
+  const auto& device_type = d->attributes().device_type();
+  return device_type.find("XLA") != std::string::npos;
+}
+
 string DeviceName(tensorflow::Device* d) {
   return (d == nullptr) ? "cpu:0" : d->name();
 }
@@ -183,7 +192,10 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
       (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
   const bool dst_cpu = IsCPU(dstd);
   const bool src_cpu = IsCPU(srcd);
-  if (is_same_device) {
+  // both_on_cpu can be true and yet is_same_device is false, if one of src/dst
+  // has device type XLA_CPU, and the other CPU.
+  const bool both_on_cpu = src_cpu && dst_cpu;
+  if (is_same_device || both_on_cpu) {
     return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd);
   }
   tensorflow::Tensor* src = &(h->t);
index 7fe9c98726798abe29c00b5f40b338baa92fda60..bf4a9fe6cea594b649419b3520ec943c92fc34e3 100644 (file)
@@ -219,6 +219,13 @@ def tf_copts(android_optimization_level_override="-O2", is_external=False):
             "//conditions:default": ["-pthread"]
       }))
 
+
+def tfe_xla_copts():
+  return select({
+      "//tensorflow:with_xla_support": ["-DTENSORFLOW_EAGER_USE_XLA"],
+      "//conditions:default": [],
+  })
+
 def tf_opts_nortti_if_android():
   return if_android([
       "-fno-rtti",
@@ -666,6 +673,7 @@ def tf_cuda_cc_test(name,
                     tags=[],
                     data=[],
                     size="medium",
+                    extra_copts=[],
                     linkstatic=0,
                     args=[],
                     linkopts=[]):
@@ -676,6 +684,7 @@ def tf_cuda_cc_test(name,
       tags=tags + ["manual"],
       data=data,
       size=size,
+      extra_copts=extra_copts,
       linkstatic=linkstatic,
       linkopts=linkopts,
       args=args)
@@ -696,6 +705,7 @@ def tf_cuda_cc_test(name,
       tags=tags + tf_cuda_tests_tags(),
       data=data,
       size=size,
+      extra_copts=extra_copts,
       linkopts=linkopts,
       args=args)