Add a method XlaTensor::ReleaseShapedBuffer() to relinquish the shaped buffer owned...
author Peter Hawkins <phawkins@google.com>
Wed, 23 May 2018 13:45:12 +0000 (06:45 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 23 May 2018 13:48:15 +0000 (06:48 -0700)
Add an equality operator for xla::ShapeIndexView.

PiperOrigin-RevId: 197716313

tensorflow/compiler/jit/xla_device_context.cc
tensorflow/compiler/jit/xla_tensor.cc
tensorflow/compiler/jit/xla_tensor.h
tensorflow/compiler/xla/shape_util.cc
tensorflow/compiler/xla/shape_util.h
tensorflow/core/framework/tensor.h

diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index ff30b62..c764834 100644
@@ -60,10 +60,11 @@ Status XlaTransferManager::TransferLiteralToDevice(
     const Tensor& host_tensor, Tensor* device_tensor) const {
   xla::Literal literal;
   TF_RETURN_IF_ERROR(HostTensorToLiteral(host_tensor, &literal));
-  VLOG(1) << "Transfer to device as literal: " << literal.ToString();
 
   const xla::ShapedBuffer& shaped_buffer =
       XlaTensor::FromTensor(device_tensor)->shaped_buffer();
+  VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " "
+          << shaped_buffer.ToString();
   return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal,
                                                     shaped_buffer);
 }
@@ -76,7 +77,8 @@ Status XlaTransferManager::TransferLiteralFromDevice(
   TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
                       transfer_manager_->TransferLiteralFromDevice(
                           stream_->parent(), shaped_buffer));
-  VLOG(1) << "Transfer from device as literal: " << literal->ToString();
+  VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " "
+          << shaped_buffer.ToString();
   Tensor tensor;
   TF_RETURN_IF_ERROR(
       LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
@@ -98,7 +100,9 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
             << " "
             << reinterpret_cast<const void*>(
                    device_tensor->tensor_data().data())
-            << " " << cpu_tensor->NumElements();
+            << " " << cpu_tensor->NumElements() << " "
+            << cpu_tensor->shape().DebugString() << " "
+            << device_tensor->shape().DebugString();
 
     void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
     const int64 total_bytes = cpu_tensor->TotalBytes();
@@ -165,7 +169,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
                    device_tensor->tensor_data().data())
             << " "
             << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
-            << device_tensor->NumElements();
+            << " " << device_tensor->NumElements() << " "
+            << cpu_tensor->shape().DebugString() << " "
+            << device_tensor->shape().DebugString();
 
     const int64 total_bytes = cpu_tensor->TotalBytes();
     se::DeviceMemoryBase dev_src_ptr =
diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc
index a7211c9..3c44c4a 100644
@@ -18,7 +18,7 @@ limitations under the License.
 
 namespace tensorflow {
 
-/*static*/ XlaTensor* XlaTensor::FromTensor(Tensor* tensor) {
+/*static*/ XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) {
   if (tensor->NumElements() == 0) {
     return nullptr;
   }
@@ -27,8 +27,8 @@ namespace tensorflow {
   return xla_tensor;
 }
 
-/*static*/ const XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) {
-  return FromTensor(const_cast<Tensor*>(tensor));
+/*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) {
+  return tensor.RefCountIsOne();
 }
 
 /*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor(
@@ -67,6 +67,8 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
     index_to_buffer.second = buffer.Forget();
   }
 
+  VLOG(4) << shaped_buffer.ToString();
+
   set_shaped_buffer(std::move(shaped_buffer));
   return Status::OK();
 }
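
The new static helper XlaTensor::RefCountIsOne(const Tensor&) simply forwards to Tensor's private RefCountIsOne(), made reachable by the friend declaration added to tensor.h below. As a rough sketch only (the helper and its policy are hypothetical, not part of this change), such a check could gate reuse of the device buffer backing a tensor:

  #include "tensorflow/compiler/jit/xla_tensor.h"

  // Hypothetical helper: the device buffer backing `t` may only be reused if
  // the tensor carries an XLA ShapedBuffer and no other Tensor shares the
  // underlying TensorBuffer.
  bool CanReuseDeviceBuffer(const tensorflow::Tensor& t) {
    tensorflow::XlaTensor* xla_tensor = tensorflow::XlaTensor::FromTensor(&t);
    return xla_tensor != nullptr && xla_tensor->has_shaped_buffer() &&
           tensorflow::XlaTensor::RefCountIsOne(t);
  }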
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index 6b29c82..c54001a 100644
@@ -34,10 +34,9 @@ class XlaTensor {
  public:
   // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast
   // fails.
-  static XlaTensor* FromTensor(Tensor* tensor);
-  // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast
-  // fails.
-  static const XlaTensor* FromTensor(const Tensor* tensor);
+  static XlaTensor* FromTensor(const Tensor* tensor);
+
+  static bool RefCountIsOne(const Tensor& tensor);
 
   // Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in
   // which case the returned value is shaped_buffer()->root_buffer(), or a
@@ -62,6 +61,10 @@ class XlaTensor {
     CHECK(has_shaped_buffer());
     return *shaped_buffer_;
   }
+  xla::ShapedBuffer& shaped_buffer() {
+    CHECK(has_shaped_buffer());
+    return *shaped_buffer_;
+  }
   // Mutates the XlaTensor to set the ShapedBuffer.
   void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
     shaped_buffer_ =
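
The added non-const shaped_buffer() overload gives code that owns an XlaTensor mutable access to its xla::ShapedBuffer without rebuilding it through set_shaped_buffer(). A minimal sketch, assuming ShapedBuffer::set_buffer(memory, index) with its usual XLA signature; the helper name is hypothetical:

  namespace tensorflow {

  // Hypothetical helper: swap a fresh device allocation into the root buffer
  // of an existing XlaTensor in place via the new mutable accessor.
  void ReplaceRootBuffer(XlaTensor* xla_tensor,
                         const se::DeviceMemoryBase& new_root) {
    xla::ShapedBuffer& buffers = xla_tensor->shaped_buffer();
    buffers.set_buffer(new_root, /*index=*/{});  // {} addresses the root
  }

  }  // namespace tensorflow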
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 7a897f6..2cdee30 100644
@@ -55,6 +55,23 @@ string ShapeIndexView::ToString() const {
       "}");
 }
 
+bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
+  if (size() != other.size()) {
+    return false;
+  }
+  for (auto it = begin(), other_it = other.begin(); it != end();
+       ++it, ++other_it) {
+    if (*it != *other_it) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool ShapeIndexView::operator!=(const ShapeIndexView& other) const {
+  return !(*this == other);
+}
+
 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
   out << shape_index.ToString();
   return out;
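
For reference, a small hypothetical usage of the new comparison, assuming the existing ShapeIndex initializer-list constructor and ShapeIndexView's converting constructor from ShapeIndex:

  // Views compare element-wise, so equal index sequences compare equal even
  // when they come from different ShapeIndex objects.
  xla::ShapeIndex a({1, 0});
  xla::ShapeIndex b({1, 0});
  CHECK(xla::ShapeIndexView(a) == xla::ShapeIndexView(b));
  b.push_back(2);  // b is now {1, 0, 2}
  CHECK(xla::ShapeIndexView(a) != xla::ShapeIndexView(b));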
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index cb8bf5a..73e0148 100644
@@ -132,6 +132,9 @@ class ShapeIndexView {
     return ShapeIndexView(new_begin, end_);
   }
 
+  bool operator==(const ShapeIndexView& other) const;
+  bool operator!=(const ShapeIndexView& other) const;
+
   string ToString() const;
 
  private:
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 58fbced..d2f2609 100644
@@ -484,6 +484,7 @@ class Tensor {
   friend class TensorTestHelper;      // For access to set_shape
   friend class OpKernelContext;       // For access to RefCountIsOne().
   friend class ScopedAllocator;       // For access to buf_.
+  friend class XlaTensor;             // For access to RefCountIsOne().
   friend class XlaTensorBuffer;  // For access to the private constructor taking
                                  // the buffer
   template <typename Device, typename T>