From 0898ee302cb20d9fce50dae4f484816a2dc2d0e2 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 26 Feb 2018 11:57:30 -0800 Subject: [PATCH] Use optimized ops to handle GPU memory swapping: this avoids the need for 2 pairs of extra _send/_recv nodes which speeds things up a bit. This also ensures that performance doesn't depend on the recv scheduling built in TF, which isn't always optimal. PiperOrigin-RevId: 187057831 --- tensorflow/core/grappler/optimizers/BUILD | 36 ++++++++- .../grappler/optimizers/gpu_swapping_kernels.cc | 88 ++++++++++++++++++++++ .../core/grappler/optimizers/gpu_swapping_ops.cc | 58 ++++++++++++++ .../core/grappler/optimizers/memory_optimizer.cc | 9 ++- .../grappler/optimizers/memory_optimizer_test.cc | 65 ++++++++++++---- tensorflow/core/grappler/utils/BUILD | 1 + tensorflow/core/grappler/utils/grappler_test.cc | 17 +++++ tensorflow/core/grappler/utils/grappler_test.h | 3 + 8 files changed, 258 insertions(+), 19 deletions(-) create mode 100644 tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc create mode 100644 tensorflow/core/grappler/optimizers/gpu_swapping_ops.cc diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 50ba48e..908e58b 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -1,6 +1,8 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") +load("//tensorflow:tensorflow.bzl", "tf_kernel_library") filegroup( name = "all_files", @@ -282,18 +284,48 @@ tf_cc_test( ], ) +tf_kernel_library( + name = "gpu_swapping_kernels", + srcs = [ + "gpu_swapping_kernels.cc", + ], + deps = [ + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "gpu_swapping_ops", + srcs = [ + "gpu_swapping_ops.cc", + ], + deps = [ + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + cc_library( name = "memory_optimizer", - srcs = ["memory_optimizer.cc"], + srcs = [ + "memory_optimizer.cc", + ], hdrs = [ "memory_optimizer.h", ], visibility = ["//visibility:public"], deps = [ + ":gpu_swapping_kernels", + ":gpu_swapping_ops", ":graph_optimizer", ":graph_rewriter", ":static_schedule", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", @@ -307,7 +339,7 @@ cc_library( ], ) -tf_cc_test( +tf_cc_test_gpu( name = "memory_optimizer_test", srcs = ["memory_optimizer_test.cc"], deps = [ diff --git a/tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc b/tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc new file mode 100644 index 0000000..1820af6 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Op kernels used to swap data in and out of GPU memory. + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace { + +class CopyFromGpuToHostKernel : public AsyncOpKernel { + public: + explicit CopyFromGpuToHostKernel(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + const Tensor& input = ctx->input(0); + OP_REQUIRES_ASYNC( + ctx, !ctx->input_alloc_attr(0).on_host(), + errors::Internal("The input tensor to the _CopyFromGpuToHost kernel " + "must reside on the device."), + done); + + AllocatorAttributes alloc_attrs; + alloc_attrs.set_gpu_compatible(true); + alloc_attrs.set_on_host(true); + Tensor* output; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->allocate_output(0, input.shape(), &output, alloc_attrs), + done); + + ctx->op_device_context()->CopyDeviceTensorToCPU( + &input, "CopyFromGpuToHost", static_cast(ctx->device()), + output, [ctx, done](const Status& s) { + ctx->SetStatus(s); + done(); + }); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("_CopyFromGpuToHost").Device(DEVICE_GPU).HostMemory("output"), + CopyFromGpuToHostKernel); + +class CopyFromHostToGpuKernel : public AsyncOpKernel { + public: + explicit CopyFromHostToGpuKernel(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + const Tensor& input = ctx->input(0); + OP_REQUIRES_ASYNC( + ctx, ctx->input_alloc_attr(0).on_host(), + errors::Internal("The input tensor to the _CopyFromHostToGpu kernel " + "must reside on the host."), + done); + + Tensor* output; + OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, input.shape(), &output), + done); + + ctx->op_device_context()->CopyCPUTensorToDevice( + &input, static_cast(ctx->device()), output, + [ctx, done](const Status& s) { + ctx->SetStatus(s); + done(); + }); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("_CopyFromHostToGpu").Device(DEVICE_GPU).HostMemory("input"), + CopyFromHostToGpuKernel); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/gpu_swapping_ops.cc b/tensorflow/core/grappler/optimizers/gpu_swapping_ops.cc new file mode 100644 index 0000000..4682834 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/gpu_swapping_ops.cc @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Definition for the ops used to swap data in and out of GPU memory. + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace { + +// The _CopyFromGpuToHost op copies its input tensor to the host. The input must +// reside on GPU. The op itself must be placed on GPU. +REGISTER_OP("_CopyFromGpuToHost") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + auto* handle_data = c->input_handle_shapes_and_types(0); + if (handle_data != nullptr) { + c->set_output_handle_shapes_and_types(0, *handle_data); + } + return Status::OK(); + }) + .Doc("Copies the input tensor from gpu to the host."); + +// The _CopyFromHostToGpu op copies its input tensor from the host to the GPU. +// The input must reside on CPU. The op itself must be placed on GPU. +REGISTER_OP("_CopyFromHostToGpu") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + auto* handle_data = c->input_handle_shapes_and_types(0); + if (handle_data != nullptr) { + c->set_output_handle_shapes_and_types(0, *handle_data); + } + return Status::OK(); + }) + .Doc("Copies the input tensor from the host to the GPU."); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index dec4f04..694139f 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -720,18 +720,19 @@ Status BuildSwapPair(NodeDef* node, int input_to_swap, // Force the tensor to be copied to cpu. NodeDef* swap_out_node = graph->add_node(); swap_out_node->set_name(swap_out_name); - swap_out_node->set_op("Identity"); - swap_out_node->set_device("/device:CPU:0"); + swap_out_node->set_op("_CopyFromGpuToHost"); // Force the tensor to be restored to the device. NodeDef* swap_in_node = graph->add_node(); swap_in_node->set_name(swap_in_name); - swap_in_node->set_op("Identity"); + swap_in_node->set_op("_CopyFromHostToGpu"); *swap_in_node->add_input() = swap_out_node->name(); - // Colocate the swap_in_ node with the node itself. + // Colocate the swap_out_ and swap_in_ nodes with the node itself. + swap_out_node->set_device(node->device()); swap_in_node->set_device(node->device()); string coloc_group = strings::StrCat("loc@", tensor_to_swap); + (*swap_out_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group); (*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group); (*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group); diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc index 5d7913e..9595936 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc @@ -221,16 +221,20 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) { // Build a simple graph with an op that's marked for swapping. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output a = ops::Variable(s.WithOpName("a"), {10, 10}, DT_FLOAT); - Output b = ops::AddN(s.WithOpName("b"), {a}); - Output c = ops::AddN(s.WithOpName("c"), {b}); - Output d = ops::AddN(s.WithOpName("d"), {c}); - Output e = ops::AddN(s.WithOpName("e"), {b, d}); + Output a = + ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"), {10, 10}, DT_FLOAT); + Output b = ops::AddN(s.WithOpName("b").WithDevice("/gpu:0"), {a}); + Output c = ops::AddN(s.WithOpName("c").WithDevice("/gpu:0"), {b}); + Output d = ops::AddN(s.WithOpName("d").WithDevice("/gpu:0"), {c}); + Output e = ops::AddN(s.WithOpName("e").WithDevice("/gpu:0"), {b, d}); + + Output constant = ops::Const(s.WithOpName("constant"), 0.0f, {10, 10}); + Output init = ops::Assign(s.WithOpName("init"), a, constant); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - EXPECT_EQ(5, item.graph.node_size()); + EXPECT_EQ(7, item.graph.node_size()); EXPECT_EQ(NodeName(e.name()), item.graph.node(4).name()); AttrValue& val = (*item.graph.mutable_node(4)->mutable_attr())["_swap_to_host"]; @@ -243,32 +247,43 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) { Status status = optimizer.Optimize(cluster.get(), item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(7, output.node_size()); - const NodeDef& new_e = output.node(4); + EXPECT_EQ(9, output.node_size()); + const NodeDef& new_e = output.node(6); EXPECT_EQ(NodeName(e.name()), new_e.name()); EXPECT_EQ(2, new_e.input_size()); EXPECT_EQ(NodeName(d.name()), new_e.input(1)); EXPECT_EQ("swap_in_e_0", new_e.input(0)); - const NodeDef& swap_out = output.node(5); + const NodeDef& swap_out = output.node(7); EXPECT_EQ("swap_out_e_0", swap_out.name()); + EXPECT_EQ("_CopyFromGpuToHost", swap_out.op()); - const NodeDef& swap_in = output.node(6); + const NodeDef& swap_in = output.node(8); EXPECT_EQ("swap_in_e_0", swap_in.name()); + EXPECT_EQ("_CopyFromHostToGpu", swap_in.op()); EXPECT_EQ(NodeName(b.name()), swap_out.input(0)); EXPECT_EQ(NodeName(swap_out.name()), swap_in.input(0)); EXPECT_EQ("^c", swap_in.input(1)); - const NodeDef& new_c = output.node(2); + const NodeDef& new_c = output.node(4); EXPECT_EQ(NodeName(c.name()), new_c.name()); EXPECT_EQ("^swap_out_e_0", new_c.input(1)); // Run the optimizer a second time to ensure it's idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(cluster.get(), item, &output); + GrapplerItem item_copy(item, std::move(output)); + status = optimizer.Optimize(cluster.get(), item_copy, &output); TF_EXPECT_OK(status); + +#if GOOGLE_CUDA + item.fetch = {"e"}; + item.init_ops = {init.name()}; + auto tensors_expected = EvaluateFetchNodes(item); + GrapplerItem optimized(item, std::move(output)); + auto tensors = EvaluateFetchNodes(optimized); + test::ExpectTensorEqual(tensors_expected[0], tensors[0]); +#endif } TEST_F(MemoryOptimizerTest, SwappingHeuristics) { @@ -287,9 +302,13 @@ TEST_F(MemoryOptimizerTest, SwappingHeuristics) { Output h = ops::Exp(s.WithOpName("h").WithDevice("/gpu:0"), c); Output i = ops::Log(s.WithOpName("i").WithDevice("/gpu:0"), d); + Output constant = ops::Const(s.WithOpName("constant"), 0.0f, {128, 128, 8}); + Output init = ops::Assign(s.WithOpName("init"), v, constant); + GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); item.fetch = {"e", "f", "g", "h", "i"}; + item.init_ops = {init.name()}; std::unique_ptr cluster(CreateVirtualCluster()); @@ -308,6 +327,15 @@ TEST_F(MemoryOptimizerTest, SwappingHeuristics) { EXPECT_EQ("axis", node.input(4)); } } + +#if GOOGLE_CUDA + auto tensors_expected = EvaluateFetchNodes(item); + GrapplerItem optimized(item, std::move(output)); + auto tensors = EvaluateFetchNodes(optimized); + for (int i = 0; i < item.fetch.size(); ++i) { + test::ExpectTensorEqual(tensors_expected[i], tensors[i]); + } +#endif } TEST_F(MemoryOptimizerTest, UnswappableInputs) { @@ -325,9 +353,13 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) { Output e = ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis); + Output constant = ops::Const(s.WithOpName("constant"), 0.0f, {128, 128, 8}); + Output init = ops::Assign(s.WithOpName("init"), v, constant); + GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); item.fetch = {"e"}; + item.init_ops = {init.name()}; std::unique_ptr cluster(CreateVirtualCluster()); @@ -344,6 +376,13 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) { EXPECT_EQ("^swap_out_d_2", node.input(4)); } } + +#if GOOGLE_CUDA + auto tensors_expected = EvaluateFetchNodes(item); + GrapplerItem optimized(item, std::move(output)); + auto tensors = EvaluateFetchNodes(optimized); + test::ExpectTensorEqual(tensors_expected[0], tensors[0]); +#endif } TEST_F(MemoryOptimizerTest, AccumulationRewrites) { diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index 0a9dbe2..5d32609 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -142,6 +142,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", + "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", ], ) diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc index fed46c0..fef8e97 100644 --- a/tensorflow/core/grappler/utils/grappler_test.cc +++ b/tensorflow/core/grappler/utils/grappler_test.cc @@ -35,6 +35,23 @@ std::vector GrapplerTest::EvaluateNodes( return output_tensors; } +std::vector GrapplerTest::EvaluateFetchNodes(const GrapplerItem& item) { + SessionOptions options; + std::unique_ptr session(NewSession(options)); + TF_CHECK_OK(session->Create(item.graph)); + RunOptions run_options; + if (!item.init_ops.empty()) { + std::vector dummy; + TF_CHECK_OK( + session->Run(run_options, {}, {}, item.init_ops, &dummy, nullptr)); + } + std::vector output_tensors; + TF_CHECK_OK( + session->Run(run_options, {}, item.fetch, {}, &output_tensors, nullptr)); + TF_CHECK_OK(session->Close()); + return output_tensors; +} + void GrapplerTest::AddNode(const string& name, const string& op, const std::vector& inputs, GraphDef* graph) { auto* node = graph->add_node(); diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h index 042b616..fd6809b 100644 --- a/tensorflow/core/grappler/utils/grappler_test.h +++ b/tensorflow/core/grappler/utils/grappler_test.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -30,6 +31,8 @@ class GrapplerTest : public ::testing::Test { std::vector EvaluateNodes(const GraphDef& graph, const std::vector& node_names); + std::vector EvaluateFetchNodes(const GrapplerItem& item); + void AddNode(const string& name, const string& op, const std::vector& inputs, GraphDef* graph); -- 2.7.4