const Tensor& host_tensor, Tensor* device_tensor) const {
xla::Literal literal;
TF_RETURN_IF_ERROR(HostTensorToLiteral(host_tensor, &literal));
- VLOG(1) << "Transfer to device as literal: " << literal.ToString();
const xla::ShapedBuffer& shaped_buffer =
XlaTensor::FromTensor(device_tensor)->shaped_buffer();
+ VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " "
+ << shaped_buffer.ToString();
return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal,
shaped_buffer);
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
transfer_manager_->TransferLiteralFromDevice(
stream_->parent(), shaped_buffer));
- VLOG(1) << "Transfer from device as literal: " << literal->ToString();
+ VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " "
+ << shaped_buffer.ToString();
Tensor tensor;
TF_RETURN_IF_ERROR(
LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
<< " "
<< reinterpret_cast<const void*>(
device_tensor->tensor_data().data())
- << " " << cpu_tensor->NumElements();
+ << " " << cpu_tensor->NumElements() << " "
+ << cpu_tensor->shape().DebugString() << " "
+ << device_tensor->shape().DebugString();
void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
const int64 total_bytes = cpu_tensor->TotalBytes();
device_tensor->tensor_data().data())
<< " "
<< reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
- << device_tensor->NumElements();
+ << " " << device_tensor->NumElements() << " "
+ << cpu_tensor->shape().DebugString() << " "
+ << device_tensor->shape().DebugString();
const int64 total_bytes = cpu_tensor->TotalBytes();
se::DeviceMemoryBase dev_src_ptr =
namespace tensorflow {
-/*static*/ XlaTensor* XlaTensor::FromTensor(Tensor* tensor) {
+/*static*/ XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) {
if (tensor->NumElements() == 0) {
return nullptr;
}
return xla_tensor;
}
-/*static*/ const XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) {
- return FromTensor(const_cast<Tensor*>(tensor));
+// Returns true iff the tensor's underlying buffer has exactly one reference,
+// i.e. no other Tensor aliases it. Thin forwarder to the private
+// Tensor::RefCountIsOne() (access granted via friendship in tensor.h).
+/*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) {
+  return tensor.RefCountIsOne();
}
/*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor(
index_to_buffer.second = buffer.Forget();
}
+ VLOG(4) << shaped_buffer.ToString();
+
set_shaped_buffer(std::move(shaped_buffer));
return Status::OK();
}
public:
// Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast
// fails.
- static XlaTensor* FromTensor(Tensor* tensor);
- // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast
- // fails.
- static const XlaTensor* FromTensor(const Tensor* tensor);
+  static XlaTensor* FromTensor(const Tensor* tensor);
+
+  // True iff `tensor`'s buffer reference count is one (sole owner); forwards
+  // to the private Tensor::RefCountIsOne().
+  static bool RefCountIsOne(const Tensor& tensor);
// Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in
// which case the returned value is shaped_buffer()->root_buffer(), or a
CHECK(has_shaped_buffer());
return *shaped_buffer_;
}
+  // Mutable overload of shaped_buffer(). CHECK-fails if no buffer has been
+  // set yet (see has_shaped_buffer()).
+  xla::ShapedBuffer& shaped_buffer() {
+    CHECK(has_shaped_buffer());
+    return *shaped_buffer_;
+  }
// Mutates the XlaTensor to set the ShapedBuffer.
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
shaped_buffer_ =
"}");
}
+// Elementwise equality: two views are equal iff they have the same length
+// and identical entries in the same order.
+bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
+  if (size() != other.size()) {
+    return false;
+  }
+  // Sizes match, so advancing both iterators in lockstep is safe; `other_it`
+  // cannot run past other.end() before `it` reaches end().
+  for (auto it = begin(), other_it = other.begin(); it != end();
+       ++it, ++other_it) {
+    if (*it != *other_it) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Defined as the negation of operator== to keep the two consistent.
+bool ShapeIndexView::operator!=(const ShapeIndexView& other) const {
+  return !(*this == other);
+}
+
std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
out << shape_index.ToString();
return out;
return ShapeIndexView(new_begin, end_);
}
+  // Elementwise comparison of the two views' index sequences.
+  bool operator==(const ShapeIndexView& other) const;
+  bool operator!=(const ShapeIndexView& other) const;
+
string ToString() const;
private:
friend class TensorTestHelper; // For access to set_shape
friend class OpKernelContext; // For access to RefCountIsOne().
friend class ScopedAllocator; // For access to buf_.
+ friend class XlaTensor; // For access to RefCountIsOne().
friend class XlaTensorBuffer; // For access to the private constructor taking
// the buffer
template <typename Device, typename T>