From fa7b5a9d1ab654bbd466487e39a8b3f83c17f3f0 Mon Sep 17 00:00:00 2001 From: Chris Leary Date: Thu, 3 May 2018 18:08:34 -0700 Subject: [PATCH] [XLA] Make LocalShapedBuffer::FromLiteral fallible by passing StatusOr wrapper. PiperOrigin-RevId: 195345724 --- .../compiler/xla/python/local_computation_builder.cc | 16 ++++++++-------- .../compiler/xla/python/local_computation_builder.h | 6 ++++-- .../compiler/xla/python/local_computation_builder.i | 13 +++++++++++++ 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 7102f46..0444581 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -104,25 +104,25 @@ static StatusOr ToBuffer(LocalClient* client, } /* static */ -LocalShapedBuffer* LocalShapedBuffer::FromLiteral( +StatusOr LocalShapedBuffer::FromLiteral( const Literal& argument, const tensorflow::gtl::optional& shape_with_layout) { LocalClient* client = GetOrCreateLocalClient(); - ScopedShapedBuffer buf = [&] { + StatusOr buf = [&] { if (shape_with_layout) { std::unique_ptr relaid = argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, /*device_ordinal=*/0, *relaid) - .ConsumeValueOrDie(); + return ToBuffer(client, /*device_ordinal=*/0, *relaid); } - return ToBuffer(client, /*device_ordinal=*/0, argument).ConsumeValueOrDie(); + return ToBuffer(client, /*device_ordinal=*/0, argument); }(); - return new LocalShapedBuffer(std::move(buf)); + TF_RETURN_IF_ERROR(buf.status()); + return new LocalShapedBuffer(std::move(buf).ValueOrDie()); } -std::unique_ptr LocalShapedBuffer::ToLiteral() const { +StatusOr> LocalShapedBuffer::ToLiteral() const { LocalClient* client = GetOrCreateLocalClient(); - return client->ShapedBufferToLiteral(*shaped_buffer()).ConsumeValueOrDie(); + return client->ShapedBufferToLiteral(*shaped_buffer()); } CompiledLocalComputation::CompiledLocalComputation( diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index e104890..5ec0978 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -59,12 +59,14 @@ StatusOr > TransferFromOutfeedLocalReplica( // client. class LocalShapedBuffer { public: - static LocalShapedBuffer* FromLiteral( + static StatusOr FromLiteral( const Literal& argument, const tensorflow::gtl::optional& shape_with_layout); + LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); const ScopedShapedBuffer* shaped_buffer() const; - std::unique_ptr ToLiteral() const; + + StatusOr > ToLiteral() const; private: ScopedShapedBuffer shaped_buffer_; diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index ac792e8..b8cce5a 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -205,6 +205,19 @@ tensorflow::ImportNumpy(); } } +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::LocalShapedBuffer*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + %typemap(out) StatusOr< std::unique_ptr > { if ($1.ok()) { std::unique_ptr value = $1.ConsumeValueOrDie(); -- 2.7.4