}
/* static */
-LocalShapedBuffer* LocalShapedBuffer::FromLiteral(
+StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
const Literal& argument,
const tensorflow::gtl::optional<Shape>& shape_with_layout) {
LocalClient* client = GetOrCreateLocalClient();
- ScopedShapedBuffer buf = [&] {
+ StatusOr<ScopedShapedBuffer> buf = [&] {
if (shape_with_layout) {
std::unique_ptr<Literal> relaid =
argument.Relayout(shape_with_layout.value());
- return ToBuffer(client, /*device_ordinal=*/0, *relaid)
- .ConsumeValueOrDie();
+ return ToBuffer(client, /*device_ordinal=*/0, *relaid);
}
- return ToBuffer(client, /*device_ordinal=*/0, argument).ConsumeValueOrDie();
+ return ToBuffer(client, /*device_ordinal=*/0, argument);
}();
- return new LocalShapedBuffer(std::move(buf));
+ TF_RETURN_IF_ERROR(buf.status());
+ return new LocalShapedBuffer(std::move(buf).ValueOrDie());
}
-std::unique_ptr<Literal> LocalShapedBuffer::ToLiteral() const {
+StatusOr<std::unique_ptr<Literal>> LocalShapedBuffer::ToLiteral() const {
LocalClient* client = GetOrCreateLocalClient();
- return client->ShapedBufferToLiteral(*shaped_buffer()).ConsumeValueOrDie();
+ return client->ShapedBufferToLiteral(*shaped_buffer());
}
CompiledLocalComputation::CompiledLocalComputation(
// client.
class LocalShapedBuffer {
public:
- static LocalShapedBuffer* FromLiteral(
+ static StatusOr<LocalShapedBuffer*> FromLiteral(
const Literal& argument,
const tensorflow::gtl::optional<Shape>& shape_with_layout);
+
LocalShapedBuffer(ScopedShapedBuffer shaped_buffer);
const ScopedShapedBuffer* shaped_buffer() const;
- std::unique_ptr<Literal> ToLiteral() const;
+
+ StatusOr<std::unique_ptr<Literal> > ToLiteral() const;
private:
ScopedShapedBuffer shaped_buffer_;
}
}
+%typemap(out) StatusOr<xla::swig::LocalShapedBuffer*> {
+ if ($1.ok()) {
+ auto* value = $1.ValueOrDie();
+ {
+ auto* $1 = value;
+ $typemap(out, xla::swig::LocalShapedBuffer*)
+ }
+ } else {
+ PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
+ SWIG_fail;
+ }
+}
+
%typemap(out) StatusOr< std::unique_ptr<Literal> > {
if ($1.ok()) {
std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();