From 2f5f2cb4253b4eaf7953cf7ed28f76e0bdee6fcc Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 11 May 2018 16:04:54 -0700 Subject: [PATCH] [XLA] s/tensorflow::Status/Status/. These are type aliases of one another; we'd like to be consistent and use the shorter one. PiperOrigin-RevId: 196322955 --- tensorflow/compiler/xla/BUILD | 3 +- tensorflow/compiler/xla/client/client.cc | 6 +- tensorflow/compiler/xla/client/global_data.cc | 2 +- tensorflow/compiler/xla/client/local_client.cc | 8 +- tensorflow/compiler/xla/client/local_client.h | 8 +- tensorflow/compiler/xla/layout_util.cc | 22 ++- tensorflow/compiler/xla/layout_util.h | 11 +- tensorflow/compiler/xla/rpc/grpc_service.cc | 4 +- tensorflow/compiler/xla/rpc/grpc_stub.cc | 116 ++++++------ tensorflow/compiler/xla/rpc/grpc_stub.h | 121 ++++++------- .../compiler/xla/service/allocation_tracker.cc | 6 +- tensorflow/compiler/xla/service/buffer_liveness.cc | 4 +- tensorflow/compiler/xla/service/buffer_liveness.h | 2 +- .../compiler/xla/service/compile_only_service.h | 40 ++--- .../xla/service/cpu/cpu_layout_assignment.cc | 2 +- .../compiler/xla/service/cpu/dot_op_emitter.cc | 14 +- .../compiler/xla/service/cpu/dot_op_emitter.h | 8 +- .../compiler/xla/service/device_memory_allocator.h | 6 +- .../compiler/xla/service/execution_tracker.cc | 8 +- .../compiler/xla/service/execution_tracker.h | 4 +- .../compiler/xla/service/gpu/buffer_allocations.cc | 2 +- .../compiler/xla/service/gpu/buffer_allocations.h | 3 +- tensorflow/compiler/xla/service/gpu/copy_thunk.cc | 8 +- tensorflow/compiler/xla/service/gpu/copy_thunk.h | 8 +- tensorflow/compiler/xla/service/gpu/fft_thunk.cc | 6 +- tensorflow/compiler/xla/service/gpu/fft_thunk.h | 4 +- tensorflow/compiler/xla/service/gpu/for_thunk.cc | 12 +- tensorflow/compiler/xla/service/gpu/for_thunk.h | 8 +- tensorflow/compiler/xla/service/gpu/gemm_thunk.cc | 6 +- tensorflow/compiler/xla/service/gpu/gemm_thunk.h | 4 +- .../compiler/xla/service/gpu/gpu_compiler.cc | 9 +- .../compiler/xla/service/gpu/kernel_thunk.cc | 12 +- tensorflow/compiler/xla/service/gpu/kernel_thunk.h | 8 +- .../gpu/llvm_gpu_backend/gpu_backend_lib.cc | 13 +- .../compiler/xla/service/gpu/sequential_thunk.cc | 10 +- .../compiler/xla/service/gpu/sequential_thunk.h | 8 +- tensorflow/compiler/xla/service/gpu/thunk.h | 10 +- tensorflow/compiler/xla/service/gpu/tuple_thunk.cc | 6 +- tensorflow/compiler/xla/service/gpu/tuple_thunk.h | 4 +- .../compiler/xla/service/gpu/while_transformer.cc | 12 +- tensorflow/compiler/xla/service/hlo_verifier.cc | 38 ++-- tensorflow/compiler/xla/service/hlo_verifier.h | 4 +- .../compiler/xla/service/layout_assignment_test.cc | 13 +- .../xla/service/llvm_ir/fused_ir_emitter.cc | 2 +- .../compiler/xla/service/llvm_ir/loop_emitter.cc | 9 +- .../compiler/xla/service/llvm_ir/loop_emitter.h | 5 +- tensorflow/compiler/xla/service/service.cc | 200 ++++++++++----------- tensorflow/compiler/xla/service/service.h | 133 ++++++-------- tensorflow/compiler/xla/service/shape_inference.cc | 30 ++-- .../compiler/xla/service/transpose_folding.cc | 2 +- tensorflow/compiler/xla/service_interface.h | 114 ++++++------ tensorflow/compiler/xla/shape_layout.cc | 8 +- tensorflow/compiler/xla/shape_layout.h | 4 +- tensorflow/compiler/xla/status.h | 2 +- tensorflow/compiler/xla/statusor_test.cc | 2 +- tensorflow/compiler/xla/test_helpers.h | 29 ++- .../compiler/xla/tests/client_library_test_base.cc | 26 ++- .../compiler/xla/tests/client_library_test_base.h | 8 +- .../compiler/xla/tests/local_client_test_base.cc | 3 +- .../compiler/xla/tests/local_client_test_base.h | 3 +- tensorflow/compiler/xla/tests/params_test.cc | 2 +- tensorflow/compiler/xla/text_literal_writer.cc | 4 +- tensorflow/compiler/xla/text_literal_writer.h | 4 +- .../compiler/xla/tools/parser/hlo_parser_test.cc | 20 +-- 64 files changed, 558 insertions(+), 655 deletions(-) diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 729480e..4304045 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -99,9 +99,9 @@ cc_library( hdrs = ["service_interface.h"], visibility = [":friends"], deps = [ + ":status", ":xla_data_proto", ":xla_proto", - "//tensorflow/core:lib", ], ) @@ -245,6 +245,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protobuf_util", + ":status", ":status_macros", ":statusor", ":types", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 328e1b8..0a79b3c 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -336,7 +336,7 @@ StatusOr>> Client::ExecuteParallel( ExecuteParallelResponse response; VLOG(1) << "making execute-parallel request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->ExecuteParallel(&request, &response); + Status s = stub_->ExecuteParallel(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -372,7 +372,7 @@ StatusOr>> Client::ExecuteParallel( ExecuteParallelResponse response; VLOG(1) << "making execute-graph-parallel request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->ExecuteGraphParallel(&request, &response); + Status s = stub_->ExecuteGraphParallel(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -401,7 +401,7 @@ StatusOr> Client::GetDeviceHandles( GetDeviceHandlesResponse response; VLOG(1) << "making get device request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->GetDeviceHandles(&request, &response); + Status s = stub_->GetDeviceHandles(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { diff --git a/tensorflow/compiler/xla/client/global_data.cc b/tensorflow/compiler/xla/client/global_data.cc index 40f59ea..2986d40 100644 --- a/tensorflow/compiler/xla/client/global_data.cc +++ b/tensorflow/compiler/xla/client/global_data.cc @@ -31,7 +31,7 @@ GlobalData::~GlobalData() { *request.mutable_data() = handle_; UnregisterResponse response; VLOG(1) << "requesting to unregister " << handle_.ShortDebugString(); - tensorflow::Status s = parent_->Unregister(&request, &response); + Status s = parent_->Unregister(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 1acc6f8..9d44d3a 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -48,7 +48,7 @@ LocalExecutable::LocalExecutable(std::unique_ptr executable, << "Must have a valid device ordinal that the executable was built for."; } -tensorflow::Status LocalExecutable::ValidateExecutionOptions( +Status LocalExecutable::ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend) { const ComputationLayout& host_computation_layout = @@ -207,7 +207,7 @@ StatusOr LocalExecutable::ExecuteAndDump( return std::move(result); } -tensorflow::Status LocalExecutable::RecordArguments( +Status LocalExecutable::RecordArguments( const tensorflow::gtl::ArraySlice arguments, SessionModule* session_module) { session_module->clear_arguments(); @@ -219,8 +219,8 @@ tensorflow::Status LocalExecutable::RecordArguments( return Status::OK(); } -tensorflow::Status LocalExecutable::RecordResult( - const ShapedBuffer* result, SessionModule* session_module) { +Status LocalExecutable::RecordResult(const ShapedBuffer* result, + SessionModule* session_module) { session_module->clear_result(); TF_ASSIGN_OR_RETURN(std::unique_ptr literal, LiteralFromShapedBuffer(*result)); diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index d8fd7a5..3195037 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -59,7 +59,7 @@ class LocalExecutable { // Validates that the given arguments and options satisfy various constraints // of the computation. - tensorflow::Status ValidateExecutionOptions( + Status ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend); @@ -71,13 +71,13 @@ class LocalExecutable { // Records the arguments used to invoke the computation in a SessionModule // proto. - tensorflow::Status RecordArguments( + Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, SessionModule* session_module); // Records the result of the computation in a SessionModule proto. - tensorflow::Status RecordResult(const ShapedBuffer* result, - SessionModule* session_module); + Status RecordResult(const ShapedBuffer* result, + SessionModule* session_module); // Returns a literal containing the contents of the given ShapedBuffer. StatusOr> LiteralFromShapedBuffer( diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index c6f8f67..a76fdcd 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -140,8 +140,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { LayoutUtil::SetToDefaultLayout(program_shape->mutable_result()); } -/* static */ tensorflow::Status LayoutUtil::ValidateLayoutInShape( - const Shape& shape) { +/* static */ Status LayoutUtil::ValidateLayoutInShape(const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { // Tuple shape. if (shape.has_layout()) { @@ -150,12 +149,12 @@ Layout CreateDefaultLayoutForRank(int64 rank) { for (auto& element_shape : shape.tuple_shapes()) { TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); } - return tensorflow::Status::OK(); + return Status::OK(); } else if (ShapeUtil::IsOpaque(shape)) { if (shape.has_layout()) { return InvalidArgument("opaque should not have a layout field"); } - return tensorflow::Status::OK(); + return Status::OK(); } else { // Array shape. if (!shape.has_layout()) { @@ -166,14 +165,14 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } } -/* static */ tensorflow::Status LayoutUtil::ValidateLayoutForShape( - const Layout& layout, const Shape& shape) { +/* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout, + const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } if (ShapeUtil::IsOpaque(shape)) { - return tensorflow::Status::OK(); + return Status::OK(); } if (layout.format() == INVALID_FORMAT) { @@ -225,7 +224,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } } - return tensorflow::Status::OK(); + return Status::OK(); } /* static */ void LayoutUtil::ClearLayout(Shape* shape) { @@ -384,7 +383,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { namespace { // Internal helper for recursively copying layouts. -tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) { +Status CopyLayoutInternal(const Shape& src, Shape* dst) { if (ShapeUtil::IsTuple(src) != ShapeUtil::IsTuple(*dst)) { return InvalidArgument( "cannot copy layout from shape: shape structure differs"); @@ -411,14 +410,13 @@ tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) { dst->clear_layout(); } } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace /* static */ -tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, - Shape* dst) { +Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { return CopyLayoutInternal(src, dst); } diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 6cec750..d3d6a2c 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -20,9 +20,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -61,12 +61,12 @@ class LayoutUtil { static void SetToDefaultLayout(ProgramShape* program_shape); // Validates that the layout within the given shape is correct. - static tensorflow::Status ValidateLayoutInShape(const Shape& shape); + static Status ValidateLayoutInShape(const Shape& shape); // Validates that the provided layout satisfies invariants for the given // shape. - static tensorflow::Status ValidateLayoutForShape(const Layout& layout, - const Shape& shape); + static Status ValidateLayoutForShape(const Layout& layout, + const Shape& shape); // Clears the layout in the given Shape. After this function is called, // HasLayout will return false for the shape. @@ -179,8 +179,7 @@ class LayoutUtil { // tuples. 'src' and 'dst' need not be compatible but the two shapes must // have the same tuple structure (if any) and arrays must have the same // rank. within the shapes must have the same number of dimensions. - static tensorflow::Status CopyLayoutBetweenShapes(const Shape& src, - Shape* dst); + static Status CopyLayoutBetweenShapes(const Shape& src, Shape* dst); // Returns true if the layouts of lhs and rhs are equal, false // otherwise. Recursively compares layouts of tuples. diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index ffb72fc..5f4dc6b 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -27,8 +27,8 @@ namespace xla { return std::move(grpc_service); } -::grpc::Status DelegateRPC(std::function op) { - tensorflow::Status s = op(); +::grpc::Status DelegateRPC(std::function op) { + Status s = op(); return tensorflow::ToGrpcStatus(s); } diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc index e1f2b0a..620ac6c 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.cc +++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc @@ -20,53 +20,49 @@ namespace xla { GRPCStub::~GRPCStub() = default; -tensorflow::Status MakeRPC( +Status MakeRPC( const std::function<::grpc::Status(::grpc::ClientContext*)>& rpc_method) { ::grpc::ClientContext context; ::grpc::Status s = rpc_method(&context); return tensorflow::FromGrpcStatus(s); } -tensorflow::Status GRPCStub::TransferToClient( - const TransferToClientRequest* request, - TransferToClientResponse* response) { +Status GRPCStub::TransferToClient(const TransferToClientRequest* request, + TransferToClientResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferToClient(context, *request, response); }); } -tensorflow::Status GRPCStub::TransferToServer( - const TransferToServerRequest* request, - TransferToServerResponse* response) { +Status GRPCStub::TransferToServer(const TransferToServerRequest* request, + TransferToServerResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferToServer(context, *request, response); }); } -tensorflow::Status GRPCStub::TransferToInfeed( - const TransferToInfeedRequest* request, - TransferToInfeedResponse* response) { +Status GRPCStub::TransferToInfeed(const TransferToInfeedRequest* request, + TransferToInfeedResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferToInfeed(context, *request, response); }); } -tensorflow::Status GRPCStub::TransferFromOutfeed( - const TransferFromOutfeedRequest* request, - TransferFromOutfeedResponse* response) { +Status GRPCStub::TransferFromOutfeed(const TransferFromOutfeedRequest* request, + TransferFromOutfeedResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferFromOutfeed(context, *request, response); }); } -tensorflow::Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, - ResetDeviceResponse* response) { +Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, + ResetDeviceResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ResetDevice(context, *request, response); }); } -tensorflow::Status GRPCStub::LoadComputationSnapshot( +Status GRPCStub::LoadComputationSnapshot( const LoadComputationSnapshotRequest* request, LoadComputationSnapshotResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -74,28 +70,28 @@ tensorflow::Status GRPCStub::LoadComputationSnapshot( }); } -tensorflow::Status GRPCStub::Execute(const ExecuteRequest* request, - ExecuteResponse* response) { +Status GRPCStub::Execute(const ExecuteRequest* request, + ExecuteResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->Execute(context, *request, response); }); } -tensorflow::Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) { +Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, + ExecuteResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ExecuteGraph(context, *request, response); }); } -tensorflow::Status GRPCStub::ExecuteParallel( - const ExecuteParallelRequest* request, ExecuteParallelResponse* response) { +Status GRPCStub::ExecuteParallel(const ExecuteParallelRequest* request, + ExecuteParallelResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ExecuteParallel(context, *request, response); }); } -tensorflow::Status GRPCStub::ExecuteGraphParallel( +Status GRPCStub::ExecuteGraphParallel( const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -103,38 +99,35 @@ tensorflow::Status GRPCStub::ExecuteGraphParallel( }); } -tensorflow::Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request, - ExecuteAsyncResponse* response) { +Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request, + ExecuteAsyncResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ExecuteAsync(context, *request, response); }); } -tensorflow::Status GRPCStub::WaitForExecution( - const WaitForExecutionRequest* request, - WaitForExecutionResponse* response) { +Status GRPCStub::WaitForExecution(const WaitForExecutionRequest* request, + WaitForExecutionResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->WaitForExecution(context, *request, response); }); } -tensorflow::Status GRPCStub::DeconstructTuple( - const DeconstructTupleRequest* request, - DeconstructTupleResponse* response) { +Status GRPCStub::DeconstructTuple(const DeconstructTupleRequest* request, + DeconstructTupleResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->DeconstructTuple(context, *request, response); }); } -tensorflow::Status GRPCStub::GetComputationStats( - const ComputationStatsRequest* request, - ComputationStatsResponse* response) { +Status GRPCStub::GetComputationStats(const ComputationStatsRequest* request, + ComputationStatsResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetComputationStats(context, *request, response); }); } -tensorflow::Status GRPCStub::GetComputationGraphStats( +Status GRPCStub::GetComputationGraphStats( const ComputationGraphStatsRequest* request, ComputationStatsResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -142,81 +135,77 @@ tensorflow::Status GRPCStub::GetComputationGraphStats( }); } -tensorflow::Status GRPCStub::GetComputationShape( - const GetComputationShapeRequest* request, - GetComputationShapeResponse* response) { +Status GRPCStub::GetComputationShape(const GetComputationShapeRequest* request, + GetComputationShapeResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetComputationShape(context, *request, response); }); } -tensorflow::Status GRPCStub::GetShape(const GetShapeRequest* request, - GetShapeResponse* response) { +Status GRPCStub::GetShape(const GetShapeRequest* request, + GetShapeResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetShape(context, *request, response); }); } -tensorflow::Status GRPCStub::GetDeviceHandles( - const GetDeviceHandlesRequest* request, - GetDeviceHandlesResponse* response) { +Status GRPCStub::GetDeviceHandles(const GetDeviceHandlesRequest* request, + GetDeviceHandlesResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetDeviceHandles(context, *request, response); }); } -tensorflow::Status GRPCStub::CreateChannelHandle( - const CreateChannelHandleRequest* request, - CreateChannelHandleResponse* response) { +Status GRPCStub::CreateChannelHandle(const CreateChannelHandleRequest* request, + CreateChannelHandleResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->CreateChannelHandle(context, *request, response); }); } // Methods used by ComputationBuilder. -tensorflow::Status GRPCStub::Computation(const ComputationRequest* request, - ComputationResponse* response) { +Status GRPCStub::Computation(const ComputationRequest* request, + ComputationResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->Computation(context, *request, response); }); } -tensorflow::Status GRPCStub::Op(const OpRequest* request, - OpResponse* response) { +Status GRPCStub::Op(const OpRequest* request, OpResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->CreateOp(context, *request, response); }); } -tensorflow::Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request, - GetLocalShapeResponse* response) { +Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request, + GetLocalShapeResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetLocalShape(context, *request, response); }); } -tensorflow::Status GRPCStub::SetReturnValue( - const SetReturnValueRequest* request, SetReturnValueResponse* responses) { +Status GRPCStub::SetReturnValue(const SetReturnValueRequest* request, + SetReturnValueResponse* responses) { return MakeRPC([this, request, responses](::grpc::ClientContext* context) { return grpc_stub_->SetReturnValue(context, *request, responses); }); } -tensorflow::Status GRPCStub::IsConstant(const IsConstantRequest* request, - IsConstantResponse* response) { +Status GRPCStub::IsConstant(const IsConstantRequest* request, + IsConstantResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->IsConstant(context, *request, response); }); } -tensorflow::Status GRPCStub::ComputeConstant( - const ComputeConstantRequest* request, ComputeConstantResponse* response) { +Status GRPCStub::ComputeConstant(const ComputeConstantRequest* request, + ComputeConstantResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ComputeConstant(context, *request, response); }); } -tensorflow::Status GRPCStub::ComputeConstantGraph( +Status GRPCStub::ComputeConstantGraph( const ComputeConstantGraphRequest* request, ComputeConstantResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -225,17 +214,16 @@ tensorflow::Status GRPCStub::ComputeConstantGraph( } // Methods used by Computation. -tensorflow::Status GRPCStub::SnapshotComputation( - const SnapshotComputationRequest* request, - SnapshotComputationResponse* response) { +Status GRPCStub::SnapshotComputation(const SnapshotComputationRequest* request, + SnapshotComputationResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->SnapshotComputation(context, *request, response); }); } // Methods used by GlobalData. -tensorflow::Status GRPCStub::Unregister(const UnregisterRequest* request, - UnregisterResponse* response) { +Status GRPCStub::Unregister(const UnregisterRequest* request, + UnregisterResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->Unregister(context, *request, response); }); diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h index fd9810d..5906d45 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.h +++ b/tensorflow/compiler/xla/rpc/grpc_stub.h @@ -28,105 +28,90 @@ class GRPCStub : public ServiceInterface { explicit GRPCStub(grpc::XlaService::Stub* stub) : grpc_stub_(stub) {} ~GRPCStub() override; - tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, - TransferToClientResponse* result) override; + Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) override; - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override; + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override; - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override; + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override; + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override; + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; - tensorflow::Status LoadComputationSnapshot( + Status LoadComputationSnapshot( const LoadComputationSnapshotRequest* request, LoadComputationSnapshotResponse* result) override; - tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override; + Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; - tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) override; + Status ExecuteGraph(const ExecuteGraphRequest* request, + ExecuteResponse* response) override; - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; + Status ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) override; - tensorflow::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* request, - ExecuteParallelResponse* response) override; + Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request, + ExecuteParallelResponse* response) override; - tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; + Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) override; - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override; + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override; - tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) override; + Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) override; - tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; + Status GetComputationStats(const ComputationStatsRequest* arg, + ComputationStatsResponse* result) override; - tensorflow::Status GetComputationGraphStats( - const ComputationGraphStatsRequest* request, - ComputationStatsResponse* response) override; + Status GetComputationGraphStats(const ComputationGraphStatsRequest* request, + ComputationStatsResponse* response) override; - tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; + Status GetComputationShape(const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) override; - tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) override; + Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) override; - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override; + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; - tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) override; + Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) override; // Methods used by ComputationBuilder. - tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; + Status Computation(const ComputationRequest* arg, + ComputationResponse* result) override; - tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override; - tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; + Status Op(const OpRequest* arg, OpResponse* result) override; + Status GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) override; - tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; + Status SetReturnValue(const SetReturnValueRequest* arg, + SetReturnValueResponse* results) override; - tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; + Status IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) override; - tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; + Status ComputeConstant(const ComputeConstantRequest* arg, + ComputeConstantResponse* result) override; - tensorflow::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, - ComputeConstantResponse* result) override; + Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) override; // Methods used by Computation. - tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) override; + Status SnapshotComputation(const SnapshotComputationRequest* ag, + SnapshotComputationResponse* result) override; // Methods used by GlobalData. - tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) override; + Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) override; grpc::XlaService::Stub* service() { return grpc_stub_; } diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index eb52803..95b4cb6 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -101,7 +101,7 @@ StatusOr AllocationTracker::RegisterInternal( return result; } -tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { +Status AllocationTracker::Unregister(const GlobalDataHandle& data) { tensorflow::mutex_lock lock(mutex_); VLOG(2) << "Unregister(" << "handle: " << data.handle() << ")"; @@ -130,7 +130,7 @@ tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { for (auto& shaped_buffer : it->second) { shaped_buffer.reset(); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr> AllocationTracker::DeconstructTuple( @@ -242,7 +242,7 @@ Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory, } else { allocation.ref_count--; } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 37982aa..acb546a 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -44,7 +44,7 @@ StatusOr> BufferLiveness::Run( return std::move(liveness); } -tensorflow::Status BufferLiveness::Analyze() { +Status BufferLiveness::Analyze() { TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_)); for (auto* computation : module_->computations()) { if (computation->IsFusionComputation()) { @@ -71,7 +71,7 @@ tensorflow::Status BufferLiveness::Analyze() { } XLA_VLOG_LINES(3, ToString()); - return tensorflow::Status::OK(); + return Status::OK(); } string BufferLiveness::ToString() const { diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h index 11834a5..cdd3cf4 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.h +++ b/tensorflow/compiler/xla/service/buffer_liveness.h @@ -89,7 +89,7 @@ class BufferLiveness { // Perform buffer liveness analysis. This method must be called prior to // MayInterfere or MaybeLiveOut. - tensorflow::Status Analyze(); + Status Analyze(); // Returns true if the live range of the buffer of 'a' is strictly before the // live range of the buffer of 'b' (they do not overlap). diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index c10609e..7f2ce0e 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -75,48 +75,42 @@ class CompileOnlyService : public Service { // Override Service methods that require or imply the existence of an // execute backend. Note that this does not include TransferToClient, as // computing constants produces global data that we may wish to transfer. - tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override { + Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); } - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override { + Status ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); } - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override { + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } - tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override { + Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); } - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override { + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); } - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override { + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override { + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override { + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override { + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index 85c461e..aa872d5 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -179,7 +179,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( } } } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 81c0d67..5cdfc11 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -542,7 +542,7 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} -/* static */ tensorflow::Status DotOpEmitter::EmitDotOperation( +/* static */ Status DotOpEmitter::EmitDotOperation( const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, @@ -691,7 +691,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { return true; } -tensorflow::Status DotOpEmitter::Emit() { +Status DotOpEmitter::Emit() { // The dot operation performs a sum of products over dimension 0 of the left // hand side operand and dimension 1 of the right hand side operand. // @@ -869,10 +869,10 @@ tensorflow::Status DotOpEmitter::Emit() { // loop. ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status DotOpEmitter::EmitScalarDot() { +Status DotOpEmitter::EmitScalarDot() { // A scalar dot is just a scalar multiply. llvm::Value* result; llvm::Value* lhs_value = @@ -897,10 +897,10 @@ tensorflow::Status DotOpEmitter::EmitScalarDot() { result = ir_builder_->CreateFMul(lhs_value, rhs_value); } target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status DotOpEmitter::EmitCallToRuntime() { +Status DotOpEmitter::EmitCallToRuntime() { // The signature of the Eigen runtime matmul function is: // // (void)(void* run_options, float* out, float* lhs, float* rhs, @@ -1002,7 +1002,7 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { ir_builder_->getInt64(mat_mult_dims.k), ir_builder_->getInt32(transpose_lhs), ir_builder_->getInt32(transpose_rhs)}); - return tensorflow::Status::OK(); + return Status::OK(); } DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index e5ede06..566f07b 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -57,7 +57,7 @@ class DotOpEmitter { // dimensions as the result, and the result is computed as `addend_array` + // dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported // for Matrix-vector products. - static tensorflow::Status EmitDotOperation( + static Status EmitDotOperation( const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, @@ -76,18 +76,18 @@ class DotOpEmitter { const TargetMachineFeatures& target_machine_features); // Emits the IR to perform the dot operation. - tensorflow::Status Emit(); + Status Emit(); // Emits instructions to perform a scalar dot product (a multiply of the // LHS and RHS) and store the results in the target. - tensorflow::Status EmitScalarDot(); + Status EmitScalarDot(); // Emit an LLVM IR implementation of the dot operation if we can. Returns // true if an LLVM IR implementation was emitted. bool EmitLlvmIrDotIfProfitable(); // Emits a call to the CPU runtime to perform the matrix multiply. - tensorflow::Status EmitCallToRuntime(); + Status EmitCallToRuntime(); // Emits a series of nested loops for iterating over an operand array in the // dot operation. Loops are constructed in major to minor dimension layout diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h index 5feb650..d87b86c 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.h +++ b/tensorflow/compiler/xla/service/device_memory_allocator.h @@ -60,8 +60,7 @@ class DeviceMemoryAllocator { } // Must be a nop for null pointers. - virtual tensorflow::Status Deallocate(int device_ordinal, - se::DeviceMemoryBase mem) = 0; + virtual Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) = 0; // Return the platform that the allocator allocates memory on. const se::Platform* platform() const { return platform_; } @@ -89,8 +88,7 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { // Pull in two-arg overload that sets retry_on_failure to true. using DeviceMemoryAllocator::Allocate; - tensorflow::Status Deallocate(int device_ordinal, - se::DeviceMemoryBase mem) override; + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; bool AllowsAsynchronousDeallocation() const override; diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 2f0b9ed..6794cfe 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -37,11 +37,11 @@ AsyncExecution::AsyncExecution(Backend* backend, } } -tensorflow::Status AsyncExecution::BlockUntilDone() const { +Status AsyncExecution::BlockUntilDone() const { for (auto& stream : streams_) { TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); } - return tensorflow::Status::OK(); + return Status::OK(); } ExecutionTracker::ExecutionTracker() : next_handle_(1) {} @@ -61,7 +61,7 @@ ExecutionHandle ExecutionTracker::Register( return execution_handle; } -tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { +Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { @@ -69,7 +69,7 @@ tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { handle.handle()); } handle_to_execution_.erase(handle.handle()); - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr ExecutionTracker::Resolve( diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h index 5b6bddf..4458152 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.h +++ b/tensorflow/compiler/xla/service/execution_tracker.h @@ -43,7 +43,7 @@ class AsyncExecution { AsyncExecution(Backend* backend, std::vector streams, const ExecutionProfile& profile, GlobalDataHandle result); - tensorflow::Status BlockUntilDone() const; + Status BlockUntilDone() const; const GlobalDataHandle& result() const { return result_; } @@ -77,7 +77,7 @@ class ExecutionTracker { GlobalDataHandle data); // Unregisters the execution for the given handle. - tensorflow::Status Unregister(const ExecutionHandle& handle); + Status Unregister(const ExecutionHandle& handle); // Resolves the given ExecutionHandle to an AsyncExecution. Returns an // error status if the given handle is not found, which means that the diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index cb66d37..ab5149d 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -116,7 +116,7 @@ BufferAllocations::~BufferAllocations() { } } -tensorflow::Status BufferAllocations::TearDown( +Status BufferAllocations::TearDown( const std::set& live_addresses) { // Deallocate temporary buffers, taking care to try to deallocate all of them // even if one of the deallocations fails. diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index a36571d..6366235 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -78,8 +78,7 @@ class BufferAllocations { // Tears down all buffers allocated by this object that are not in // `live_addresses`. - tensorflow::Status TearDown( - const std::set& live_addresses); + Status TearDown(const std::set& live_addresses); private: BufferAllocations(BufferAllocation::Index buffer_count, int device_ordinal, diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc index bf912fb..ee38c03 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -29,12 +29,12 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk( destination_buffer_(destination_buffer), mem_size_(mem_size) {} -tensorflow::Status HostToDeviceCopyThunk::ExecuteOnStream( +Status HostToDeviceCopyThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); stream->ThenMemcpy(&destination_data, source_address_, mem_size_); - return tensorflow::Status::OK(); + return Status::OK(); } DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( @@ -46,14 +46,14 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( destination_buffer_(destination_buffer), mem_size_(mem_size) {} -tensorflow::Status DeviceToDeviceCopyThunk::ExecuteOnStream( +Status DeviceToDeviceCopyThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); se::DeviceMemoryBase source_data = buffer_allocations.GetDeviceAddress(source_buffer_); stream->ThenMemcpy(&destination_data, source_data, mem_size_); - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h index 2e7eb5f..8b12838 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -39,8 +39,8 @@ class HostToDeviceCopyThunk : public Thunk { HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete; HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const void* source_address_; @@ -62,8 +62,8 @@ class DeviceToDeviceCopyThunk : public Thunk { DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete; DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const BufferAllocation::Slice source_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index 1cea493..e14ee69 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -106,8 +106,8 @@ FftThunk::FftThunk(FftType fft_type, input_shape_(input_shape), output_shape_(output_shape) {} -tensorflow::Status FftThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { VLOG(3) << "FFT type: " << FftTypeToString(fft_type_); VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_); VLOG(3) << "Output shape: " @@ -207,7 +207,7 @@ tensorflow::Status FftThunk::ExecuteOnStream( LOG(FATAL) << "unsupported fft type"; } if (launch_ok) { - return tensorflow::Status::OK(); + return Status::OK(); } return InternalError("Unable to launch fft for thunk %p with type %s", this, FftTypeToString(fft_type_).c_str()); diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index ea4270a..b0a2256 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -71,8 +71,8 @@ class FftThunk : public Thunk { FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_ // Does the FFT for the thunk on "stream". - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const se::fft::Type fft_type_; diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index c49c273..b36539e 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -30,20 +30,20 @@ ForThunk::ForThunk(const int64 loop_limit, body_thunk_sequence_( MakeUnique(std::move(*body_thunk_sequence), hlo)) {} -tensorflow::Status ForThunk::Initialize(const GpuExecutable& executable, - se::StreamExecutor* executor) { +Status ForThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor)); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ForThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { for (int64 i = 0; i < loop_limit_; ++i) { // Invoke loop body thunk sequence. TF_RETURN_IF_ERROR( body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index 56c5c49..41ddfe0 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -36,10 +36,10 @@ class ForThunk : public Thunk { ForThunk(const ForThunk&) = delete; ForThunk& operator=(const ForThunk&) = delete; - tensorflow::Status Initialize(const GpuExecutable& executable, - se::StreamExecutor* executor) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const int64 loop_limit_; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index f996fe4..2ebb40a 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -232,8 +232,8 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, output_shape_(output_shape), alpha_(alpha) {} -tensorflow::Status GemmThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { VLOG(2) << "Executing a GemmThunk"; se::DeviceMemoryBase lhs_data = @@ -350,7 +350,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream( if (!launch_ok) { return InternalError("Unable to launch cuBLAS gemm on stream %p", stream); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index f42cbf9..7a4830d 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -47,8 +47,8 @@ class GemmThunk : public Thunk { GemmThunk& operator=(const GemmThunk&) = delete; // Does the gemm operation for the thunk on "stream", which must be non-null. - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; // Returns true if we'll perform autotuning if run on the given stream. If // so, we want the GPU to be quiescent during autotuning, so as not to diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 4fdc4c8..df494a1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -128,9 +128,8 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. -tensorflow::Status OptimizeHloModule(HloModule* hlo_module, - se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) { +Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) { { HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(); @@ -283,12 +282,12 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); } } - return tensorflow::Status::OK(); + return Status::OK(); } // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. -tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { +Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // In some cases, we have to place the result of an instruction in a temporary // buffer. For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 3baee22..f56c1ce 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -35,8 +35,8 @@ KernelThunk::KernelThunk( kernel_name_(kernel_name), unroll_factor_(unroll_factor) {} -tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable, - se::StreamExecutor* executor) { +Status KernelThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { tensorflow::mutex_lock lock(mutex_); if (!loader_spec_) { loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); @@ -66,7 +66,7 @@ tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable, } } - return tensorflow::Status::OK(); + return Status::OK(); } void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { @@ -74,8 +74,8 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { launch_dimensions_ = launch_dims; } -tensorflow::Status KernelThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { // Load the kernel. se::StreamExecutor* executor = stream->parent(); LaunchDimensions launch_dimensions; @@ -106,7 +106,7 @@ tensorflow::Status KernelThunk::ExecuteOnStream( *kernel_args)) { return InternalError("Unable to launch kernel %s", kernel_name_.c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index 532f15e..7def27e 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -57,12 +57,12 @@ class KernelThunk : public Thunk { int unroll_factor() const { return unroll_factor_; } void SetLaunchDimensions(const LaunchDimensions& launch_dims); - tensorflow::Status Initialize(const GpuExecutable& executable, - se::StreamExecutor* executor) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; // Executes the kernel for the thunk on "stream", which must be non-null. - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: // Buffers passed to the kernel as arguments. diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index d70cb07..917c576 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -77,8 +77,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, // Since CUDA 9.0, all GPU versions are included in a single file const char* unified_libdevice_filename = "libdevice.10.bc"; std::vector unified_libdevice_files; - const tensorflow::Status status = - tensorflow::Env::Default()->GetMatchingPaths( + const Status status = tensorflow::Env::Default()->GetMatchingPaths( tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename), &unified_libdevice_files); if (status.ok() && unified_libdevice_files.size() == 1) { @@ -311,11 +310,11 @@ bool CouldNeedLibdevice(const llvm::Module& module) { } // Links libdevice into the given module if the module needs libdevice. -tensorflow::Status LinkLibdeviceIfNecessary( - llvm::Module* module, std::pair compute_capability, - const string& libdevice_dir_path) { +Status LinkLibdeviceIfNecessary(llvm::Module* module, + std::pair compute_capability, + const string& libdevice_dir_path) { if (!CouldNeedLibdevice(*module)) { - return tensorflow::Status::OK(); + return Status::OK(); } llvm::Linker linker(*module); @@ -336,7 +335,7 @@ tensorflow::Status LinkLibdeviceIfNecessary( return tensorflow::errors::Internal(tensorflow::strings::StrCat( "Error linking libdevice from ", libdevice_path)); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr CompileModuleToPtx(llvm::Module* module, diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc index 849eff2..b50f5b5 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc @@ -24,20 +24,20 @@ SequentialThunk::SequentialThunk(std::vector>&& thunks, const HloInstruction* hlo) : Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {} -tensorflow::Status SequentialThunk::Initialize(const GpuExecutable& executable, - se::StreamExecutor* executor) { +Status SequentialThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { for (auto& thunk : thunks_) { TF_RETURN_IF_ERROR(thunk->Initialize(executable, executor)); } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status SequentialThunk::ExecuteOnStream( +Status SequentialThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { for (const auto& thunk : thunks_) { TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h index 8305791..3537110 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h @@ -38,10 +38,10 @@ class SequentialThunk : public Thunk { const std::vector>& thunks() const { return thunks_; } - tensorflow::Status Initialize(const GpuExecutable& executable, - se::StreamExecutor* executor) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: // The list of sub-thunks. diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index ff9b608..931c0bf 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -75,9 +75,9 @@ class Thunk { // This may be called multiple times. Its main purpose is to give us a chance // to do initialization outside of ExecuteOnStream() so that the // time spent initializing doesn't count towards our execution profile. - virtual tensorflow::Status Initialize(const GpuExecutable& /*executable*/, - se::StreamExecutor* /*executor*/) { - return tensorflow::Status::OK(); + virtual Status Initialize(const GpuExecutable& /*executable*/, + se::StreamExecutor* /*executor*/) { + return Status::OK(); } // Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream) @@ -97,8 +97,8 @@ class Thunk { // lifetime. Stream argument must be non-null. // // Precondition: Initialize(stream->parent()) has been called. - virtual tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) = 0; + virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) = 0; private: Kind kind_; diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index ecb5485..97cb04c 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -20,8 +20,8 @@ limitations under the License. namespace xla { namespace gpu { -tensorflow::Status TupleThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { std::vector tuple_element_buffer_addresses; for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) { tuple_element_buffer_addresses.push_back( @@ -40,7 +40,7 @@ tensorflow::Status TupleThunk::ExecuteOnStream( tuple_element_buffer_addresses.data(), dest_buffer_address.opaque(), sizeof(void*) * tuple_element_buffer_addresses.size()); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index 8b459c2..951f809 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -45,8 +45,8 @@ class TupleThunk : public Thunk { TupleThunk(const TupleThunk&) = delete; TupleThunk& operator=(const TupleThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const std::vector tuple_element_buffers_; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index e6caec8..ad55728 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -144,7 +144,7 @@ class ExprTree { TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first), tagged_instructions)); } - return tensorflow::Status::OK(); + return Status::OK(); } private: @@ -169,7 +169,7 @@ class MatcherBase { // Attempts to match each ExprTree in 'expr_trees_'. // Returns OK on the first successful match, error status otherwise. - virtual tensorflow::Status Run() { + virtual Status Run() { Status status; for (const ExprTree& expr_tree : expr_trees_) { status = MatchExprTree(expr_tree); @@ -201,7 +201,7 @@ class MatcherBase { } else if (type == S64) { *const_value = literal.GetFirstElement(); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr GetTaggedInstruction( @@ -315,7 +315,7 @@ class WhileConditionComputationMatcher : public MatcherBase { gte_fusion_param0->name().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } const HloComputation* computation_; @@ -379,7 +379,7 @@ class WhileInitOperandMatcher : public MatcherBase { GetTaggedInstruction("loop_start", tagged_instructions)); TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_)); - return tensorflow::Status::OK(); + return Status::OK(); } const HloInstruction* while_hlo_; @@ -477,7 +477,7 @@ class WhileBodyComputationMatcher : public MatcherBase { } } } - return tensorflow::Status::OK(); + return Status::OK(); } const HloComputation* computation_; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 096ebb7..7d6d0d9 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -106,9 +106,7 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -Status ShapeVerifier::HandleInfeed(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleInfeed(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // Outfeed has a separate shape field for the value which is outfed to the @@ -127,12 +125,10 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { } Status ShapeVerifier::HandleHostCompute(HloInstruction*) { - return tensorflow::Status::OK(); + return Status::OK(); } -Status ShapeVerifier::HandleRng(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleRng(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { return CheckShape( @@ -164,7 +160,7 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { @@ -183,7 +179,7 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { operand_shape.dimensions(operand_dimension)) << broadcast->ToString() << " operand shape " << operand_shape; } - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { @@ -191,7 +187,7 @@ Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == ShapeUtil::ElementsIn(reshape->operand(0)->shape())); - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { @@ -201,21 +197,17 @@ Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { } Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { - return tensorflow::Status::OK(); + return Status::OK(); } -Status ShapeVerifier::HandleFusion(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleCall(HloInstruction* call) { // The shape of kCall should match the shape of the computation it calls. return CheckShape(call, call->to_apply()->ComputeProgramShape().result()); } -Status ShapeVerifier::HandleCustomCall(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleSlice(HloInstruction* slice) { return CheckShape(slice, @@ -497,7 +489,7 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, ShapeUtil::HumanString(instruction->shape()).c_str(), instruction->ToString().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, @@ -547,7 +539,7 @@ Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1, instr1->ToString().c_str(), instr1->channel_id(), instr2->ToString().c_str(), instr2->channel_id()); } - return tensorflow::Status::OK(); + return Status::OK(); } string ComputationsToString( @@ -612,7 +604,7 @@ Status VerifyHloStructure(HloModule* module) { } } } - return tensorflow::Status::OK(); + return Status::OK(); } Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { @@ -728,7 +720,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { // TODO(b/65423525): We'd like to check that all operands are distinct. // This is currently disabled due to the invariant being violated by // multi-output fusion. - return tensorflow::Status::OK(); + return Status::OK(); } Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { @@ -777,7 +769,7 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { "init: %s, body: %s", init->ToString().c_str(), body_root->ToString().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { @@ -795,7 +787,7 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { ShapeUtil::HumanString(operand_shape).c_str()); } } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr HloVerifier::Run(HloModule* module) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 6208887..1392a78 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -82,9 +82,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; - Status FinishVisit(HloInstruction*) override { - return tensorflow::Status::OK(); - } + Status FinishVisit(HloInstruction*) override { return Status::OK(); } protected: // Check the instruction's shape against the shape given by ShapeInference diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 7e1bb11..986e177 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -660,13 +660,12 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { /*device_allocator=*/nullptr) .ConsumeValueOrDie(); - EXPECT_EQ( - ::tensorflow::Status::OK(), - backend() - .compiler() - ->RunBackend(std::move(module), backend().default_stream_executor(), - /*device_allocator=*/nullptr) - .status()); + EXPECT_EQ(Status::OK(), backend() + .compiler() + ->RunBackend(std::move(module), + backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .status()); } // A GTE inside of a fusion node inherits the layout of its operand (which diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index bc683a1..f172b1d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -151,7 +151,7 @@ Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) { Status FusedIrEmitter::FinishVisit(HloInstruction* root) { fused_root_ = root; - return tensorflow::Status::OK(); + return Status::OK(); } FusedIrEmitter::Generator FusedIrEmitter::GetRootGenerator() const { diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 3978acc..0728ccf 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -39,14 +39,13 @@ LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, const IrArray& target_array, llvm::IRBuilder<>* ir_builder) - : body_emitter_([=](const llvm_ir::IrArray::Index array_index) - -> ::tensorflow::Status { + : body_emitter_([=](const llvm_ir::IrArray::Index array_index) -> Status { // Convert target_element_generator to a BodyEmitter. TF_ASSIGN_OR_RETURN(llvm::Value * target_element, target_element_generator(array_index)); target_array.EmitWriteArrayElement(array_index, target_element, ir_builder); - return tensorflow::Status::OK(); + return Status::OK(); }), shape_(target_array.GetShape()), ir_builder_(ir_builder) {} @@ -124,7 +123,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( return {array_index}; } -tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { +Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { for (const IrArray::Index& array_index : EmitIndexAndSetExitBasicBlock(loop_name)) { TF_RETURN_IF_ERROR(body_emitter_(array_index)); @@ -135,7 +134,7 @@ tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { if (exit_bb_ != nullptr) { ir_builder_->SetInsertPoint(exit_bb_); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 9ff497a..b70d28e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -38,8 +38,7 @@ using ElementGenerator = // Emits a loop for every element in the given shape. class LoopEmitter { public: - using BodyEmitter = - std::function; + using BodyEmitter = std::function; LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, llvm::IRBuilder<>* ir_builder); @@ -72,7 +71,7 @@ class LoopEmitter { tensorflow::StringPiece loop_name); // Emits a complete loop nest for every element in the given shape. - tensorflow::Status EmitLoop(tensorflow::StringPiece loop_name = ""); + Status EmitLoop(tensorflow::StringPiece loop_name = ""); protected: // An IR emitter that generates the loop body. diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 495f880..047cadb 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -64,7 +64,7 @@ namespace { // Records the arguments used to invoke a computation in a SessionModule // proto. -tensorflow::Status RecordArguments( +Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, se::StreamExecutor* executor, TransferManager* transfer_manager, SessionModule* module) { @@ -75,24 +75,22 @@ tensorflow::Status RecordArguments( transfer_manager->TransferLiteralFromDevice(executor, *argument)); *module->add_arguments() = literal->ToProto(); } - return tensorflow::Status::OK(); + return Status::OK(); } // Records the result of a computation in a SessionModule proto. -tensorflow::Status RecordResult(const ShapedBuffer& result, - se::StreamExecutor* executor, - TransferManager* transfer_manager, - SessionModule* module) { +Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, + TransferManager* transfer_manager, SessionModule* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( std::unique_ptr literal, transfer_manager->TransferLiteralFromDevice(executor, result)); *module->mutable_result() = literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } // Records the arguments used to invoke a computation in an HloSnapshot proto. -tensorflow::Status RecordArguments( +Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, se::StreamExecutor* executor, TransferManager* transfer_manager, HloSnapshot* module) { @@ -103,20 +101,18 @@ tensorflow::Status RecordArguments( transfer_manager->TransferLiteralFromDevice(executor, *argument)); *module->add_arguments() = literal->ToProto(); } - return tensorflow::Status::OK(); + return Status::OK(); } // Records the result of a computation in a HloSnapshot proto. -tensorflow::Status RecordResult(const ShapedBuffer& result, - se::StreamExecutor* executor, - TransferManager* transfer_manager, - HloSnapshot* module) { +Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, + TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( std::unique_ptr literal, transfer_manager->TransferLiteralFromDevice(executor, result)); *module->mutable_result() = literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace @@ -199,8 +195,8 @@ Service::Service(const ServiceOptions& options, } } -tensorflow::Status Service::Computation(const ComputationRequest* arg, - ComputationResponse* result) { +Status Service::Computation(const ComputationRequest* arg, + ComputationResponse* result) { if (arg->name().empty()) { return InvalidArgument("computation request needs a name"); } @@ -210,24 +206,23 @@ tensorflow::Status Service::Computation(const ComputationRequest* arg, VLOG(1) << Printf("Created new computation %s on service %p, name %s", result->computation().ShortDebugString().c_str(), this, arg->name().c_str()); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) { +Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) { *result->mutable_channel() = channel_tracker_.NewChannel(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) { +Status Service::Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) { return allocation_tracker_.Unregister(arg->data()); } // Deconstructs a previously-allocated global handle. -tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) { +Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) { TF_ASSIGN_OR_RETURN( std::vector elements, allocation_tracker_.DeconstructTuple(arg->tuple_handle())); @@ -235,11 +230,11 @@ tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, for (auto& element : elements) { *result->add_element_handles() = element; } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ValidateResultShapeWithLayout( - const Shape& shape_with_layout, const Shape& result_shape) const { +Status Service::ValidateResultShapeWithLayout(const Shape& shape_with_layout, + const Shape& result_shape) const { if (!ShapeUtil::Compatible(shape_with_layout, result_shape)) { return InvalidArgument( "Shape used to set computation result layout %s is not compatible " @@ -511,7 +506,7 @@ Status Service::ValidateEntryComputationLayout(HloModule* module) { module->device_entry_computation_layout().result_shape(), execute_backend_->transfer_manager()->HostShapeToDeviceShape( module->host_entry_computation_layout().result_shape()))); - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr> Service::BuildExecutable( @@ -801,8 +796,8 @@ StatusOr Service::ExecuteAndRegisterResult( result_tag); } -tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { +Status Service::SetReturnValue(const SetReturnValueRequest* arg, + SetReturnValueResponse* results) { TF_ASSIGN_OR_RETURN(UserComputation * computation, computation_tracker_.Resolve(arg->computation())); return computation->SetReturnValue(arg->operand()); @@ -849,8 +844,8 @@ StatusOr>> Service::GetArguments( return replicated_arguments; } -tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) { +Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) { VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); std::vector>> all_arguments; @@ -957,11 +952,11 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, } VLOG(1) << "successfully completed 'execute-parallel' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) { +Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) { VLOG(1) << "running execute-graph-parallel request"; std::vector>> all_arguments; @@ -1058,11 +1053,11 @@ tensorflow::Status Service::ExecuteGraphParallel( } VLOG(1) << "successfully completed 'execute-graph-parallel' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) { +Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) { const int64 available_device_count = execute_backend_->device_count(); const int64 replica_count = options_.number_of_replicas(); if (replica_count <= 0) { @@ -1082,11 +1077,11 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, *result->add_device_handles() = device_handle; } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg, - ExecuteResponse* result) { +Status Service::ExecuteOneToN(const ExecuteRequest* arg, + ExecuteResponse* result) { ExecuteParallelRequest parallel_arg; *parallel_arg.add_requests() = *arg; ExecuteParallelResponse parallel_result; @@ -1094,8 +1089,8 @@ tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg, return PickParallelResponse(parallel_result, result); } -tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { +Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { ExecuteGraphParallelRequest parallel_arg; *parallel_arg.add_requests() = *arg; ExecuteParallelResponse parallel_result; @@ -1103,7 +1098,7 @@ tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, return PickParallelResponse(parallel_result, result); } -tensorflow::Status Service::PickParallelResponse( +Status Service::PickParallelResponse( const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) { // The "result device" selection is a bit hacky, but better than assuming it // is device 0. We have b/76035356 for restructuring the client API to clean @@ -1126,8 +1121,7 @@ tensorflow::Status Service::PickParallelResponse( return Status::OK(); } -tensorflow::Status Service::Execute(const ExecuteRequest* arg, - ExecuteResponse* result) { +Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { VLOG(1) << "running execute request: " << arg->ShortDebugString(); TF_ASSIGN_OR_RETURN(UserComputation * user_computation, @@ -1198,7 +1192,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, } VLOG(1) << "successfully completed 'execute' request"; - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr> Service::BuildExecutable( @@ -1243,8 +1237,8 @@ StatusOr> Service::BuildExecutable( return std::move(executable); } -tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { +Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { VLOG(1) << "running execute-graph request"; if (!arg->has_computation()) { @@ -1303,11 +1297,11 @@ tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, } VLOG(1) << "successfully completed 'execute-graph' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { +Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) { VLOG(1) << "running execute-async request: " << arg->ShortDebugString(); TF_ASSIGN_OR_RETURN(UserComputation * user_computation, @@ -1383,11 +1377,11 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, streams.clear(); VLOG(1) << "successfully completed 'execute-async' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) { +Status Service::WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) { TF_ASSIGN_OR_RETURN(const auto execution, execution_tracker_.Resolve(arg->execution())); @@ -1398,11 +1392,11 @@ tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution())); VLOG(1) << "successfully completed 'wait-for-execution' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, - TransferToClientResponse* result) { +Status Service::TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); @@ -1432,7 +1426,7 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, *result->mutable_literal() = result_literal->Relayout(*return_shape)->ToProto(); } - return tensorflow::Status::OK(); + return Status::OK(); } namespace { @@ -1450,8 +1444,8 @@ std::unique_ptr CloneShapedBufferOnDevice( } // namespace -tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, - TransferToServerResponse* result) { +Status Service::TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, Literal::CreateFromProto(arg->literal())); const Shape& shape = literal->shape(); @@ -1484,11 +1478,11 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, StrCat("TransferToServer literal of shape ", ShapeUtil::HumanString(shape)))); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) { +Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) { const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( @@ -1517,9 +1511,8 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, executor, *literal); } -tensorflow::Status Service::TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) { +Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) { const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( @@ -1545,16 +1538,16 @@ tensorflow::Status Service::TransferFromOutfeed( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( executor, arg->shape_with_layout(), &literal)); *result->mutable_literal() = literal.ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) { +Status Service::ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) { return execute_backend_->ResetDevices(); } -tensorflow::Status Service::IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) { +Status Service::IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * user_computation, computation_tracker_.Resolve(arg->computation())); @@ -1570,11 +1563,11 @@ tensorflow::Status Service::IsConstant(const IsConstantRequest* arg, user_computation->IsConstant(arg->operand(), arg->num_parameters())); result->set_is_constant(is_constant); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { +Status Service::ComputeConstant(const ComputeConstantRequest* arg, + ComputeConstantResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * user_computation, computation_tracker_.Resolve(arg->computation())); @@ -1661,11 +1654,11 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, } *result->mutable_literal() = result_literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) { +Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) { if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } @@ -1703,20 +1696,18 @@ tensorflow::Status Service::ComputeConstantGraph( } *result->mutable_literal() = result_literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) { +Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); *result->mutable_shape() = buffer->on_host_shape(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { +Status Service::GetComputationShape(const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * computation, computation_tracker_.Resolve(arg->computation())); @@ -1726,21 +1717,21 @@ tensorflow::Status Service::GetComputationShape( TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape( versioned_handle.version)); *result->mutable_program_shape() = *program_shape; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { +Status Service::GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * computation, computation_tracker_.Resolve(arg->computation())); TF_ASSIGN_OR_RETURN(*result->mutable_shape(), computation->GetShape(arg->operand())); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetComputationStats( - const ComputationStatsRequest* arg, ComputationStatsResponse* result) { +Status Service::GetComputationStats(const ComputationStatsRequest* arg, + ComputationStatsResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * user_computation, computation_tracker_.Resolve(arg->computation())); @@ -1766,10 +1757,10 @@ tensorflow::Status Service::GetComputationStats( stats.set_flop_count(analysis.flop_count()); stats.set_transcendental_count(analysis.transcendental_count()); *result->mutable_stats() = stats; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetComputationGraphStats( +Status Service::GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) { if (!arg->has_computation()) { return InvalidArgument("Computations may not be empty."); @@ -1796,11 +1787,11 @@ tensorflow::Status Service::GetComputationGraphStats( stats.set_flop_count(analysis.flop_count()); stats.set_transcendental_count(analysis.transcendental_count()); *result->mutable_stats() = stats; - return tensorflow::Status::OK(); + return Status::OK(); } template -tensorflow::Status Service::AddInstruction( +Status Service::AddInstruction( const RequestT* arg, ResponseT* result, const std::function(UserComputation*)>& adder) { @@ -1808,10 +1799,10 @@ tensorflow::Status Service::AddInstruction( computation_tracker_.Resolve(arg->computation())); TF_ASSIGN_OR_RETURN(*result->mutable_output(), adder(computation)); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { +Status Service::Op(const OpRequest* arg, OpResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * computation, computation_tracker_.Resolve(arg->computation())); StatusOr handle_status; @@ -2033,27 +2024,26 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { if (arg->has_sharding()) { TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding())); } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::SnapshotComputation( - const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { +Status Service::SnapshotComputation(const SnapshotComputationRequest* arg, + SnapshotComputationResponse* result) { TF_ASSIGN_OR_RETURN( std::unique_ptr module, computation_tracker_.SnapshotComputation(arg->computation())); result->set_allocated_module(module.release()); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::LoadComputationSnapshot( +Status Service::LoadComputationSnapshot( const LoadComputationSnapshotRequest* arg, LoadComputationSnapshotResponse* result) { TF_ASSIGN_OR_RETURN(*result->mutable_computation(), computation_tracker_.LoadSessionModule(arg->module())); - return tensorflow::Status::OK(); + return Status::OK(); } DeviceHandle Service::SingleComputationDeviceHandle() const { diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index f84fe40..81fbd41 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -85,55 +85,52 @@ class Service : public ServiceInterface { // Creates a new computation with the given name. // A unique ComputationHandle is returned. - tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; + Status Computation(const ComputationRequest* arg, + ComputationResponse* result) override; // Unregisters a previously-allocated global handle. // // If the handle given is not currently allocated, a NOT_FOUND status is // returned. - tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) override; + Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) override; // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each // element in the tuple. - tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) override; + Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) override; // Modifies the provided computation so that subsequent executions // will compute the provided ComputationDataHandle, rather than the // last expression enqueued on that Computation. - tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; + Status SetReturnValue(const SetReturnValueRequest* arg, + SetReturnValueResponse* results) override; // Executes a computation with the provided global data passed as // immutable arguments. Returns global data output and execution timing. - tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override; + Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; // Executes a computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. // // TODO(b/74197823): This is a part of a NOT YET ready refactor. - tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) override; + Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; + Status ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. // // TODO(b/74197823): This is a part of a NOT YET ready refactor. - tensorflow::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, - ExecuteParallelResponse* result) override; + Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) override; // Requests one or more device handles from the target. // @@ -143,9 +140,8 @@ class Service : public ServiceInterface { // the first set of replicas, and the next R devices to the second set of // replicas, etc. Each returned device handle represents the device with the // replica id 0. - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override; + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; // Asynchronously executes a computation with provided arguments. Invokes // the provided computation with the provided global data passed as @@ -154,38 +150,33 @@ class Service : public ServiceInterface { // (Note: The corresponding function in xla::Client was removed as part of // b/64116060, in an attempt to simplify our API. We're keeping this around // for now in case we want to expose this to clients in a different way.) - tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; + Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) override; // Waits until the specified execution is complete and returns the result. // Calling this API multiple times with the same execution handle returns the // method with an error since the execution handle is destroyed after the // first call. - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override; + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override; // Requests that global data be transferred to the client in literal form. - tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, - TransferToClientResponse* result) override; + Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) override; // Transfers data from a literal provided by the client, into device memory. - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override; + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override; // Transfers data from a literal provided by the client, into the Infeed // buffer of the device. - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override; + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; // Transfers data from the Outfeed othe device to the literal provided by the // client. - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override; + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; // Resets devices, clearing all existing state on all the devices associated // with this service (including memory allocated on the devices). @@ -196,71 +187,65 @@ class Service : public ServiceInterface { // ResetDevice should be called before an Execution that expect the device to // be in the reset state. For example, if the prior Execution modifies device // state (e.g., architectural state) that the next Execution depends on. - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override; + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; // Tests if an expression is a compile-time constant. - tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; + Status IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) override; // Computes the value of a constant expression. - tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - tensorflow::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, - ComputeConstantResponse* result) override; + Status ComputeConstant(const ComputeConstantRequest* arg, + ComputeConstantResponse* result) override; + Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) override; // Returns the shape (with layout) of an array associated with a given data // handle. - tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) override; + Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) override; // Returns the program shape of the computation associated with the given // handle. - tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; + Status GetComputationShape(const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) override; ///// // Computation-oriented methods. // Enqueues an Op on the computation. - tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override; + Status Op(const OpRequest* arg, OpResponse* result) override; // Retrieves the inferred shape for a value within a computation. - tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; + Status GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) override; // Retrieves the statistics of a computation. - tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; + Status GetComputationStats(const ComputationStatsRequest* arg, + ComputationStatsResponse* result) override; // Retrieves the statistics of a computation. // // TODO(b/74197823): This is a part of a NOT YET ready refactor. - tensorflow::Status GetComputationGraphStats( - const ComputationGraphStatsRequest* arg, - ComputationStatsResponse* result) override; + Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg, + ComputationStatsResponse* result) override; // Snapshots the current state of a computation handle into a serializable // protocol buffer form, so it can be loaded via // LoadComputationSnapshot. - tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) override; + Status SnapshotComputation(const SnapshotComputationRequest* arg, + SnapshotComputationResponse* result) override; // Loads a computation from a serialized protocol buffer created via // SnapshotComputation. - tensorflow::Status LoadComputationSnapshot( + Status LoadComputationSnapshot( const LoadComputationSnapshotRequest* arg, LoadComputationSnapshotResponse* result) override; // Creates a unique channel handle that can be used for Send/Recv // instructions. - tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) override; + Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) override; // Returns the ComputationTracker of the current service instance. // Only used in unit tests to access user computations from client. @@ -389,7 +374,7 @@ class Service : public ServiceInterface { // Convenience function for adding a function to a user computation. template - tensorflow::Status AddInstruction( + Status AddInstruction( const RequestT* arg, ResponseT* result, const std::function(UserComputation*)>& adder); @@ -397,16 +382,14 @@ class Service : public ServiceInterface { // Executes a single computation which has more than one target device. // The N devices are expected to all return an empty tuple, but one, which // will be the result of this computation. - tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg, - ExecuteResponse* result); - tensorflow::Status ExecuteOneToN(const ExecuteGraphRequest* arg, - ExecuteResponse* result); + Status ExecuteOneToN(const ExecuteRequest* arg, ExecuteResponse* result); + Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); // Convenience function which checks whether the given shape_with_layout // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. - tensorflow::Status ValidateResultShapeWithLayout( - const Shape& shape_with_layout, const Shape& result_shape) const; + Status ValidateResultShapeWithLayout(const Shape& shape_with_layout, + const Shape& result_shape) const; // Returns the stream executors assigned to the replicas represented by the // given device handle. Each device_handle is a virtual replicated device that diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index fedb42a..3500978 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -172,8 +172,8 @@ bool AllUnique(tensorflow::gtl::ArraySlice slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } -tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape, - tensorflow::StringPiece op_type) { +Status ExpectNotTupleOrOpaque(const Shape& shape, + tensorflow::StringPiece op_type) { if (ShapeUtil::IsTuple(shape)) { return InvalidArgument("Expected non-tuple argument for %s, but got %s.", std::string(op_type).c_str(), @@ -183,13 +183,13 @@ tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape, std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); } else { - return tensorflow::Status::OK(); + return Status::OK(); } } -tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, - const Shape& init_value_shape, - const PrimitiveType& input_element_type) { +Status VerifyReducerShape(const ProgramShape& reducer_shape, + const Shape& init_value_shape, + const PrimitiveType& input_element_type) { if (reducer_shape.parameters_size() != 2) { return InvalidArgument( "Reduction function must take 2 parameters, but " @@ -249,7 +249,7 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, ShapeUtil::HumanString(accumulator_shape).c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr InferWindowOutputShape(const Shape& base_shape, @@ -1218,11 +1218,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( scale_shape, "scale input of batch norm training")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == - tensorflow::Status::OK()); + Status::OK()); if (feature_index >= ShapeUtil::Rank(operand_shape)) { return InvalidArgument( @@ -1324,15 +1324,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( scale_shape, "scale input of batch norm inference")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) == - tensorflow::Status::OK()); + Status::OK()); if (feature_index >= ShapeUtil::Rank(operand_shape)) { return InvalidArgument( diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index f7a5512..ba16dc6 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -215,7 +215,7 @@ StatusOr TransposeFolding::Run(HloModule* module) { std::make_pair(instruction, operand_indices)); } } - return tensorflow::Status::OK(); + return Status::OK(); }; for (auto* comp : module->MakeNonfusionComputations()) { diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 4f64fe8..141347a 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status.h" namespace xla { @@ -32,99 +32,93 @@ class ServiceInterface { virtual ~ServiceInterface() = default; // TODO(b/31824348): Convert to use StatusOr. - virtual tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, TransferToClientResponse* result) = 0; + virtual Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) = 0; - virtual tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, TransferToServerResponse* result) = 0; + virtual Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) = 0; - virtual tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) = 0; + virtual Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) = 0; - virtual tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) = 0; + virtual Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) = 0; - virtual tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) = 0; + virtual Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) = 0; - virtual tensorflow::Status LoadComputationSnapshot( + virtual Status LoadComputationSnapshot( const LoadComputationSnapshotRequest* request, LoadComputationSnapshotResponse* result) = 0; - virtual tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) = 0; + virtual Status Execute(const ExecuteRequest* arg, + ExecuteResponse* result) = 0; - virtual tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) = 0; + virtual Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) = 0; - virtual tensorflow::Status ExecuteParallel( - const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) = 0; + virtual Status ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) = 0; - virtual tensorflow::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, - ExecuteParallelResponse* result) = 0; + virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) = 0; - virtual tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) = 0; + virtual Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) = 0; - virtual tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) = 0; + virtual Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) = 0; - virtual tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) = 0; + virtual Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) = 0; - virtual tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, ComputationStatsResponse* result) = 0; + virtual Status GetComputationStats(const ComputationStatsRequest* arg, + ComputationStatsResponse* result) = 0; - virtual tensorflow::Status GetComputationGraphStats( + virtual Status GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) = 0; - virtual tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) = 0; + virtual Status GetComputationShape(const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) = 0; - virtual tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) = 0; + virtual Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) = 0; - virtual tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) = 0; + virtual Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) = 0; - virtual tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) = 0; + virtual Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) = 0; // Methods used by ComputationBuilder. - virtual tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) = 0; + virtual Status Computation(const ComputationRequest* arg, + ComputationResponse* result) = 0; - virtual tensorflow::Status Op(const OpRequest* arg, OpResponse* result) = 0; + virtual Status Op(const OpRequest* arg, OpResponse* result) = 0; - virtual tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) = 0; + virtual Status GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) = 0; - virtual tensorflow::Status SetReturnValue( - const SetReturnValueRequest* arg, SetReturnValueResponse* results) = 0; + virtual Status SetReturnValue(const SetReturnValueRequest* arg, + SetReturnValueResponse* results) = 0; - virtual tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) = 0; + virtual Status IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) = 0; - virtual tensorflow::Status ComputeConstant( - const ComputeConstantRequest* arg, ComputeConstantResponse* result) = 0; + virtual Status ComputeConstant(const ComputeConstantRequest* arg, + ComputeConstantResponse* result) = 0; - virtual tensorflow::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, - ComputeConstantResponse* result) = 0; + virtual Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) = 0; // Methods used by Computation. - virtual tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) = 0; + virtual Status SnapshotComputation(const SnapshotComputationRequest* ag, + SnapshotComputationResponse* result) = 0; // Methods used by GlobalData. - virtual tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) = 0; + virtual Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) = 0; }; } // namespace xla diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index 789eba5..7ee366b 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -22,24 +22,24 @@ limitations under the License. namespace xla { -tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { +Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { if (!ShapeUtil::Compatible(other_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", ShapeUtil::HumanString(other_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } shape_ = other_shape; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { +Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", ShapeUtil::HumanString(*to_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } *to_shape = shape_; - return tensorflow::Status::OK(); + return Status::OK(); } void ShapeLayout::SetToDefaultLayout() { diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index a1dce75..36806da 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -40,7 +40,7 @@ class ShapeLayout { // Assigns the layouts in this ShapeLayout to the Layout fields of the given // shape. 'to_shape' and the shape of the ShapeLayout object must be // compatible. - tensorflow::Status AssignLayoutToShape(Shape* to_shape) const; + Status AssignLayoutToShape(Shape* to_shape) const; // Returns true if the Layouts in this ShapeLayout match the layouts in the // given shape. Returns false otherwise. If the given shape is not compatible @@ -49,7 +49,7 @@ class ShapeLayout { // Copies the layout from the given shape into this ShapeLayout. 'other_shape' // must be compatible with the ShapeLayout's shape. - tensorflow::Status CopyLayoutFromShape(const Shape& other_shape); + Status CopyLayoutFromShape(const Shape& other_shape); // Clears (Layout::Clear) all the Layouts stored in this object. void Clear(); diff --git a/tensorflow/compiler/xla/status.h b/tensorflow/compiler/xla/status.h index 4eb3bf3..69abb51 100644 --- a/tensorflow/compiler/xla/status.h +++ b/tensorflow/compiler/xla/status.h @@ -21,7 +21,7 @@ limitations under the License. namespace xla { -using tensorflow::Status; +using tensorflow::Status; // TENSORFLOW_STATUS_OK } // namespace xla diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index 7d76370..377a618 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -413,7 +413,7 @@ TEST(StatusOr, TestPointerValueConst) { EXPECT_EQ(&kI, thing.ValueOrDie()); } -// NOTE(tucker): tensorflow::StatusOr does not support this kind +// NOTE(tucker): StatusOr does not support this kind // of resize op. // TEST(StatusOr, StatusOrVectorOfUniquePointerCanResize) { // using EvilType = std::vector>; diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h index 17bae2e..8918350 100644 --- a/tensorflow/compiler/xla/test_helpers.h +++ b/tensorflow/compiler/xla/test_helpers.h @@ -40,13 +40,10 @@ class Literal; namespace testing { namespace internal_status { -inline const ::tensorflow::Status& GetStatus( - const ::tensorflow::Status& status) { - return status; -} +inline const Status& GetStatus(const Status& status) { return status; } template -inline const ::tensorflow::Status& GetStatus(const StatusOr& status) { +inline const Status& GetStatus(const StatusOr& status) { return status.status(); } } // namespace internal_status @@ -57,21 +54,17 @@ inline const ::tensorflow::Status& GetStatus(const StatusOr& status) { // The following macros are similar to macros in gmock, but deliberately named // differently in order to avoid conflicts in files which include both. -// Macros for testing the results of functions that return tensorflow::Status or +// Macros for testing the results of functions that return Status or // StatusOr (for any type T). -#define EXPECT_IS_OK(expression) \ - EXPECT_EQ(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) -#define EXPECT_IS_NOT_OK(expression) \ - EXPECT_NE(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define EXPECT_IS_OK(expression) \ + EXPECT_EQ(Status::OK(), xla::testing::internal_status::GetStatus(expression)) +#define EXPECT_IS_NOT_OK(expression) \ + EXPECT_NE(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #undef ASSERT_IS_OK -#define ASSERT_IS_OK(expression) \ - ASSERT_EQ(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define ASSERT_IS_OK(expression) \ + ASSERT_EQ(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #undef ASSERT_IS_NOT_OK -#define ASSERT_IS_NOT_OK(expression) \ - ASSERT_NE(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define ASSERT_IS_NOT_OK(expression) \ + ASSERT_NE(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #endif // TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_ diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index b68f309..bf8ed4d 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -177,8 +177,7 @@ void ClientLibraryTestBase::ComputeAndCompareLiteral( error, shape_with_layout)); } -tensorflow::Status -ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function arguments, const std::function choose; - choose = [&, this](int64 index) -> tensorflow::Status { + std::function choose; + choose = [&, this](int64 index) -> Status { if (index < arguments.size()) { // Try out all layouts for the operand. TF_ASSIGN_OR_RETURN(auto literal, @@ -229,7 +227,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); layout_strings.pop_back(); - return tensorflow::Status::OK(); + return Status::OK(); } std::vector minor_to_major(ShapeUtil::Rank(literal->shape())); @@ -247,7 +245,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( layout_strings.pop_back(); } while ( std::next_permutation(minor_to_major.begin(), minor_to_major.end())); - return tensorflow::Status::OK(); + return Status::OK(); } // Every argument has an assigned layout. @@ -262,13 +260,13 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( tensorflow::strings::StrAppend(&error_message, str, " "); } verify_output(*actual, error_message); - return tensorflow::Status::OK(); + return Status::OK(); }; return choose(0); } -tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, const Shape* shape_with_layout) { @@ -323,10 +321,10 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual)); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, ErrorSpec error, const Shape* shape_with_layout) { @@ -376,7 +374,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error)); - return tensorflow::Status::OK(); + return Status::OK(); } void ClientLibraryTestBase::ComputeAndCompareR1U8( diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index c8c3af0..0499fec 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -188,11 +188,11 @@ class ClientLibraryTestBase : public ::testing::Test { const Shape* shape_with_layout = nullptr); // ComputeAndCompare variant which returns an error status. - tensorflow::Status ComputeAndCompareLiteralWithStatus( + Status ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); - tensorflow::Status ComputeAndCompareLiteralWithStatus( + Status ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); @@ -378,12 +378,12 @@ class ClientLibraryTestBase : public ::testing::Test { ExecutionOptions execution_options_; private: - tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( + Status ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output); - tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts( + Status ComputeAndCompareLiteralWithAllInputLayouts( const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function TestAllocator::Allocate(int device_ordinal, retry_on_failure); } -tensorflow::Status TestAllocator::Deallocate(int device_ordinal, - se::DeviceMemoryBase mem) { +Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { VLOG(2) << "Deallocate(" << device_ordinal << ")"; { tensorflow::mutex_lock lock(count_mutex_); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 6374c79..2582265 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -48,8 +48,7 @@ class TestAllocator : public StreamExecutorMemoryAllocator { StatusOr Allocate(int device_ordinal, uint64 size, bool retry_on_failure) override; - tensorflow::Status Deallocate(int device_ordinal, - se::DeviceMemoryBase mem) override; + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; // Return the number of allocations that have been performed. int64 allocation_count() const; diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index f04db77..838f1b4 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -160,7 +160,7 @@ XLA_TEST_F(ParamsTest, MissingParameter) { auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2"); auto computation_status = builder.Build(); - ASSERT_NE(computation_status.status(), tensorflow::Status::OK()); + ASSERT_NE(computation_status.status(), Status::OK()); } XLA_TEST_F(ParamsTest, UnusedParameter) { diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 6e3061b..373c0d2 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -30,7 +30,7 @@ limitations under the License. namespace xla { -/* static */ tensorflow::Status TextLiteralWriter::WriteToPath( +/* static */ Status TextLiteralWriter::WriteToPath( const Literal& literal, tensorflow::StringPiece path) { std::unique_ptr f; auto s = tensorflow::Env::Default()->NewWritableFile(std::string(path), &f); @@ -43,7 +43,7 @@ namespace xla { return s; } - tensorflow::Status status; + Status status; tensorflow::WritableFile* f_ptr = f.get(); literal.EachCellAsString( [f_ptr, &status](tensorflow::gtl::ArraySlice indices, diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 7375493..0a1235b 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -37,8 +37,8 @@ namespace xla { // This should be readable by xla::TextLiteralReader. class TextLiteralWriter { public: - static tensorflow::Status WriteToPath(const Literal& literal, - tensorflow::StringPiece path); + static Status WriteToPath(const Literal& literal, + tensorflow::StringPiece path); private: TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralWriter); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index e100d8c..131aded 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -938,13 +938,13 @@ INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, TEST_F(HloParserTest, Empty) { const string original = ""; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, Garbage) { const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongOpcode) { @@ -958,7 +958,7 @@ ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongShape) { @@ -970,7 +970,7 @@ ENTRY %blabla (x: g32[]) -> g32[] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongOperandsSize) { @@ -983,7 +983,7 @@ ENTRY %blabla (x: f32[]) -> pred[] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, OperandNotFound) { @@ -994,7 +994,7 @@ ENTRY %blabla (x: f32[]) -> pred[] { } )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, MoreConstants) { @@ -1036,7 +1036,7 @@ ENTRY %some_2 () -> f32[2] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 1, but sees larger"); } @@ -1050,7 +1050,7 @@ ENTRY %some_2x3 () -> f32[2,3] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 2, but sees 1"); } @@ -1064,7 +1064,7 @@ ENTRY %some_2x3x2 () -> f32[2,3,2] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects 3 elements in the [0]th element"); } @@ -1079,7 +1079,7 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "is out of range for literal's primitive type F16"); } -- 2.7.4