[XLA] Make LocalShapedBuffer::FromLiteral fallible by passing StatusOr wrapper.
authorChris Leary <leary@google.com>
Fri, 4 May 2018 01:08:34 +0000 (18:08 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 4 May 2018 17:39:52 +0000 (10:39 -0700)
PiperOrigin-RevId: 195345724

tensorflow/compiler/xla/python/local_computation_builder.cc
tensorflow/compiler/xla/python/local_computation_builder.h
tensorflow/compiler/xla/python/local_computation_builder.i

index 7102f46..0444581 100644 (file)
@@ -104,25 +104,25 @@ static StatusOr<ScopedShapedBuffer> ToBuffer(LocalClient* client,
 }
 
 /* static */
-LocalShapedBuffer* LocalShapedBuffer::FromLiteral(
+StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
     const Literal& argument,
     const tensorflow::gtl::optional<Shape>& shape_with_layout) {
   LocalClient* client = GetOrCreateLocalClient();
-  ScopedShapedBuffer buf = [&] {
+  StatusOr<ScopedShapedBuffer> buf = [&] {
     if (shape_with_layout) {
       std::unique_ptr<Literal> relaid =
           argument.Relayout(shape_with_layout.value());
-      return ToBuffer(client, /*device_ordinal=*/0, *relaid)
-          .ConsumeValueOrDie();
+      return ToBuffer(client, /*device_ordinal=*/0, *relaid);
     }
-    return ToBuffer(client, /*device_ordinal=*/0, argument).ConsumeValueOrDie();
+    return ToBuffer(client, /*device_ordinal=*/0, argument);
   }();
-  return new LocalShapedBuffer(std::move(buf));
+  TF_RETURN_IF_ERROR(buf.status());
+  return new LocalShapedBuffer(std::move(buf).ValueOrDie());
 }
 
-std::unique_ptr<Literal> LocalShapedBuffer::ToLiteral() const {
+StatusOr<std::unique_ptr<Literal>> LocalShapedBuffer::ToLiteral() const {
   LocalClient* client = GetOrCreateLocalClient();
-  return client->ShapedBufferToLiteral(*shaped_buffer()).ConsumeValueOrDie();
+  return client->ShapedBufferToLiteral(*shaped_buffer());
 }
 
 CompiledLocalComputation::CompiledLocalComputation(
index e104890..5ec0978 100644 (file)
@@ -59,12 +59,14 @@ StatusOr<std::unique_ptr<Literal> > TransferFromOutfeedLocalReplica(
 // client.
 class LocalShapedBuffer {
  public:
-  static LocalShapedBuffer* FromLiteral(
+  static StatusOr<LocalShapedBuffer*> FromLiteral(
       const Literal& argument,
       const tensorflow::gtl::optional<Shape>& shape_with_layout);
+
   LocalShapedBuffer(ScopedShapedBuffer shaped_buffer);
   const ScopedShapedBuffer* shaped_buffer() const;
-  std::unique_ptr<Literal> ToLiteral() const;
+
+  StatusOr<std::unique_ptr<Literal> > ToLiteral() const;
 
  private:
   ScopedShapedBuffer shaped_buffer_;
index ac792e8..b8cce5a 100644 (file)
@@ -205,6 +205,19 @@ tensorflow::ImportNumpy();
   }
 }
 
+%typemap(out) StatusOr<xla::swig::LocalShapedBuffer*> {
+  if ($1.ok()) {
+    auto* value = $1.ValueOrDie();
+    {
+      auto* $1 = value;
+      $typemap(out, xla::swig::LocalShapedBuffer*)
+    }
+  } else {
+    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
+    SWIG_fail;
+  }
+}
+
 %typemap(out) StatusOr< std::unique_ptr<Literal> > {
   if ($1.ok()) {
     std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();