Use optimized ops to handle GPU memory swapping: this avoids the need for 2
authorBenoit Steiner <bsteiner@google.com>
Mon, 26 Feb 2018 19:57:30 +0000 (11:57 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
pairs of extra _send/_recv nodes which speeds things up a bit. This also
ensures that performance doesn't depend on the recv scheduling built in TF,
which isn't always optimal.

PiperOrigin-RevId: 187057831

tensorflow/core/grappler/optimizers/BUILD
tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc [new file with mode: 0644]
tensorflow/core/grappler/optimizers/gpu_swapping_ops.cc [new file with mode: 0644]
tensorflow/core/grappler/optimizers/memory_optimizer.cc
tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
tensorflow/core/grappler/utils/BUILD
tensorflow/core/grappler/utils/grappler_test.cc
tensorflow/core/grappler/utils/grappler_test.h

index 50ba48e..908e58b 100644 (file)
@@ -1,6 +1,8 @@
 licenses(["notice"])  # Apache 2.0
 
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
+load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
 
 filegroup(
     name = "all_files",
@@ -282,18 +284,48 @@ tf_cc_test(
     ],
 )
 
+tf_kernel_library(
+    name = "gpu_swapping_kernels",
+    srcs = [
+        "gpu_swapping_kernels.cc",
+    ],
+    deps = [
+        "//tensorflow/core:core_cpu_base",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+)
+
+cc_library(
+    name = "gpu_swapping_ops",
+    srcs = [
+        "gpu_swapping_ops.cc",
+    ],
+    deps = [
+        "//tensorflow/core:core_cpu_base",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "memory_optimizer",
-    srcs = ["memory_optimizer.cc"],
+    srcs = [
+        "memory_optimizer.cc",
+    ],
     hdrs = [
         "memory_optimizer.h",
     ],
     visibility = ["//visibility:public"],
     deps = [
+        ":gpu_swapping_kernels",
+        ":gpu_swapping_ops",
         ":graph_optimizer",
         ":graph_rewriter",
         ":static_schedule",
         "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:graph_view",
         "//tensorflow/core/grappler:grappler_item",
@@ -307,7 +339,7 @@ cc_library(
     ],
 )
 
-tf_cc_test(
+tf_cc_test_gpu(
     name = "memory_optimizer_test",
     srcs = ["memory_optimizer_test.cc"],
     deps = [
diff --git a/tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc b/tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc
new file mode 100644 (file)
index 0000000..1820af6
--- /dev/null
@@ -0,0 +1,88 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Op kernels used to swap data in and out of GPU memory.
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace {
+
+class CopyFromGpuToHostKernel : public AsyncOpKernel {
+ public:
+  explicit CopyFromGpuToHostKernel(OpKernelConstruction* context)
+      : AsyncOpKernel(context) {}
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+    const Tensor& input = ctx->input(0);
+    OP_REQUIRES_ASYNC(
+        ctx, !ctx->input_alloc_attr(0).on_host(),
+        errors::Internal("The input tensor to the _CopyFromGpuToHost kernel "
+                         "must reside on the device."),
+        done);
+
+    AllocatorAttributes alloc_attrs;
+    alloc_attrs.set_gpu_compatible(true);
+    alloc_attrs.set_on_host(true);
+    Tensor* output;
+    OP_REQUIRES_OK_ASYNC(
+        ctx, ctx->allocate_output(0, input.shape(), &output, alloc_attrs),
+        done);
+
+    ctx->op_device_context()->CopyDeviceTensorToCPU(
+        &input, "CopyFromGpuToHost", static_cast<Device*>(ctx->device()),
+        output, [ctx, done](const Status& s) {
+          ctx->SetStatus(s);
+          done();
+        });
+  }
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("_CopyFromGpuToHost").Device(DEVICE_GPU).HostMemory("output"),
+    CopyFromGpuToHostKernel);
+
+class CopyFromHostToGpuKernel : public AsyncOpKernel {
+ public:
+  explicit CopyFromHostToGpuKernel(OpKernelConstruction* context)
+      : AsyncOpKernel(context) {}
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+    const Tensor& input = ctx->input(0);
+    OP_REQUIRES_ASYNC(
+        ctx, ctx->input_alloc_attr(0).on_host(),
+        errors::Internal("The input tensor to the _CopyFromHostToGpu kernel "
+                         "must reside on the host."),
+        done);
+
+    Tensor* output;
+    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, input.shape(), &output),
+                         done);
+
+    ctx->op_device_context()->CopyCPUTensorToDevice(
+        &input, static_cast<Device*>(ctx->device()), output,
+        [ctx, done](const Status& s) {
+          ctx->SetStatus(s);
+          done();
+        });
+  }
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("_CopyFromHostToGpu").Device(DEVICE_GPU).HostMemory("input"),
+    CopyFromHostToGpuKernel);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/gpu_swapping_ops.cc b/tensorflow/core/grappler/optimizers/gpu_swapping_ops.cc
new file mode 100644 (file)
index 0000000..4682834
--- /dev/null
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Definition for the ops used to swap data in and out of GPU memory.
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace {
+
+// The _CopyFromGpuToHost op copies its input tensor to the host. The input must
+// reside on GPU. The op itself must be placed on GPU.
+REGISTER_OP("_CopyFromGpuToHost")
+    .Input("input: T")
+    .Output("output: T")
+    .Attr("T: type")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->input(0));
+      auto* handle_data = c->input_handle_shapes_and_types(0);
+      if (handle_data != nullptr) {
+        c->set_output_handle_shapes_and_types(0, *handle_data);
+      }
+      return Status::OK();
+    })
+    .Doc("Copies the input tensor from gpu to the host.");
+
+// The _CopyFromHostToGpu op copies its input tensor from the host to the GPU.
+// The input must reside on CPU. The op itself must be placed on GPU.
+REGISTER_OP("_CopyFromHostToGpu")
+    .Input("input: T")
+    .Output("output: T")
+    .Attr("T: type")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->input(0));
+      auto* handle_data = c->input_handle_shapes_and_types(0);
+      if (handle_data != nullptr) {
+        c->set_output_handle_shapes_and_types(0, *handle_data);
+      }
+      return Status::OK();
+    })
+    .Doc("Copies the input tensor from the host to the GPU.");
+
+}  // namespace
+}  // namespace tensorflow
index dec4f04..694139f 100644 (file)
@@ -720,18 +720,19 @@ Status BuildSwapPair(NodeDef* node, int input_to_swap,
   // Force the tensor to be copied to cpu.
   NodeDef* swap_out_node = graph->add_node();
   swap_out_node->set_name(swap_out_name);
-  swap_out_node->set_op("Identity");
-  swap_out_node->set_device("/device:CPU:0");
+  swap_out_node->set_op("_CopyFromGpuToHost");
 
   // Force the tensor to be restored to the device.
   NodeDef* swap_in_node = graph->add_node();
   swap_in_node->set_name(swap_in_name);
-  swap_in_node->set_op("Identity");
+  swap_in_node->set_op("_CopyFromHostToGpu");
   *swap_in_node->add_input() = swap_out_node->name();
 
-  // Colocate the swap_in_ node with the node itself.
+  // Colocate the swap_out_ and swap_in_ nodes with the node itself.
+  swap_out_node->set_device(node->device());
   swap_in_node->set_device(node->device());
   string coloc_group = strings::StrCat("loc@", tensor_to_swap);
+  (*swap_out_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
   (*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
   (*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
 
index 5d7913e..9595936 100644 (file)
@@ -221,16 +221,20 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) {
   // Build a simple graph with an op that's marked for swapping.
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
 
-  Output a = ops::Variable(s.WithOpName("a"), {10, 10}, DT_FLOAT);
-  Output b = ops::AddN(s.WithOpName("b"), {a});
-  Output c = ops::AddN(s.WithOpName("c"), {b});
-  Output d = ops::AddN(s.WithOpName("d"), {c});
-  Output e = ops::AddN(s.WithOpName("e"), {b, d});
+  Output a =
+      ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"), {10, 10}, DT_FLOAT);
+  Output b = ops::AddN(s.WithOpName("b").WithDevice("/gpu:0"), {a});
+  Output c = ops::AddN(s.WithOpName("c").WithDevice("/gpu:0"), {b});
+  Output d = ops::AddN(s.WithOpName("d").WithDevice("/gpu:0"), {c});
+  Output e = ops::AddN(s.WithOpName("e").WithDevice("/gpu:0"), {b, d});
+
+  Output constant = ops::Const(s.WithOpName("constant"), 0.0f, {10, 10});
+  Output init = ops::Assign(s.WithOpName("init"), a, constant);
 
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
 
-  EXPECT_EQ(5, item.graph.node_size());
+  EXPECT_EQ(7, item.graph.node_size());
   EXPECT_EQ(NodeName(e.name()), item.graph.node(4).name());
   AttrValue& val =
       (*item.graph.mutable_node(4)->mutable_attr())["_swap_to_host"];
@@ -243,32 +247,43 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) {
   Status status = optimizer.Optimize(cluster.get(), item, &output);
   TF_EXPECT_OK(status);
 
-  EXPECT_EQ(7, output.node_size());
-  const NodeDef& new_e = output.node(4);
+  EXPECT_EQ(9, output.node_size());
+  const NodeDef& new_e = output.node(6);
   EXPECT_EQ(NodeName(e.name()), new_e.name());
 
   EXPECT_EQ(2, new_e.input_size());
   EXPECT_EQ(NodeName(d.name()), new_e.input(1));
   EXPECT_EQ("swap_in_e_0", new_e.input(0));
 
-  const NodeDef& swap_out = output.node(5);
+  const NodeDef& swap_out = output.node(7);
   EXPECT_EQ("swap_out_e_0", swap_out.name());
+  EXPECT_EQ("_CopyFromGpuToHost", swap_out.op());
 
-  const NodeDef& swap_in = output.node(6);
+  const NodeDef& swap_in = output.node(8);
   EXPECT_EQ("swap_in_e_0", swap_in.name());
+  EXPECT_EQ("_CopyFromHostToGpu", swap_in.op());
 
   EXPECT_EQ(NodeName(b.name()), swap_out.input(0));
   EXPECT_EQ(NodeName(swap_out.name()), swap_in.input(0));
   EXPECT_EQ("^c", swap_in.input(1));
 
-  const NodeDef& new_c = output.node(2);
+  const NodeDef& new_c = output.node(4);
   EXPECT_EQ(NodeName(c.name()), new_c.name());
   EXPECT_EQ("^swap_out_e_0", new_c.input(1));
 
   // Run the optimizer a second time to ensure it's idempotent.
-  item.graph.Swap(&output);
-  status = optimizer.Optimize(cluster.get(), item, &output);
+  GrapplerItem item_copy(item, std::move(output));
+  status = optimizer.Optimize(cluster.get(), item_copy, &output);
   TF_EXPECT_OK(status);
+
+#if GOOGLE_CUDA
+  item.fetch = {"e"};
+  item.init_ops = {init.name()};
+  auto tensors_expected = EvaluateFetchNodes(item);
+  GrapplerItem optimized(item, std::move(output));
+  auto tensors = EvaluateFetchNodes(optimized);
+  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+#endif
 }
 
 TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
@@ -287,9 +302,13 @@ TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
   Output h = ops::Exp(s.WithOpName("h").WithDevice("/gpu:0"), c);
   Output i = ops::Log(s.WithOpName("i").WithDevice("/gpu:0"), d);
 
+  Output constant = ops::Const(s.WithOpName("constant"), 0.0f, {128, 128, 8});
+  Output init = ops::Assign(s.WithOpName("init"), v, constant);
+
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
   item.fetch = {"e", "f", "g", "h", "i"};
+  item.init_ops = {init.name()};
 
   std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
 
@@ -308,6 +327,15 @@ TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
       EXPECT_EQ("axis", node.input(4));
     }
   }
+
+#if GOOGLE_CUDA
+  auto tensors_expected = EvaluateFetchNodes(item);
+  GrapplerItem optimized(item, std::move(output));
+  auto tensors = EvaluateFetchNodes(optimized);
+  for (int i = 0; i < item.fetch.size(); ++i) {
+    test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
+  }
+#endif
 }
 
 TEST_F(MemoryOptimizerTest, UnswappableInputs) {
@@ -325,9 +353,13 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) {
   Output e =
       ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);
 
+  Output constant = ops::Const(s.WithOpName("constant"), 0.0f, {128, 128, 8});
+  Output init = ops::Assign(s.WithOpName("init"), v, constant);
+
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
   item.fetch = {"e"};
+  item.init_ops = {init.name()};
 
   std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
 
@@ -344,6 +376,13 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) {
       EXPECT_EQ("^swap_out_d_2", node.input(4));
     }
   }
+
+#if GOOGLE_CUDA
+  auto tensors_expected = EvaluateFetchNodes(item);
+  GrapplerItem optimized(item, std::move(output));
+  auto tensors = EvaluateFetchNodes(optimized);
+  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+#endif
 }
 
 TEST_F(MemoryOptimizerTest, AccumulationRewrites) {
index 0a9dbe2..5d32609 100644 (file)
@@ -142,6 +142,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
+        "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler:utils",
     ],
 )
index fed46c0..fef8e97 100644 (file)
@@ -35,6 +35,23 @@ std::vector<Tensor> GrapplerTest::EvaluateNodes(
   return output_tensors;
 }
 
+std::vector<Tensor> GrapplerTest::EvaluateFetchNodes(const GrapplerItem& item) {
+  SessionOptions options;
+  std::unique_ptr<tensorflow::Session> session(NewSession(options));
+  TF_CHECK_OK(session->Create(item.graph));
+  RunOptions run_options;
+  if (!item.init_ops.empty()) {
+    std::vector<Tensor> dummy;
+    TF_CHECK_OK(
+        session->Run(run_options, {}, {}, item.init_ops, &dummy, nullptr));
+  }
+  std::vector<Tensor> output_tensors;
+  TF_CHECK_OK(
+      session->Run(run_options, {}, item.fetch, {}, &output_tensors, nullptr));
+  TF_CHECK_OK(session->Close());
+  return output_tensors;
+}
+
 void GrapplerTest::AddNode(const string& name, const string& op,
                            const std::vector<string>& inputs, GraphDef* graph) {
   auto* node = graph->add_node();
index 042b616..fd6809b 100644 (file)
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -30,6 +31,8 @@ class GrapplerTest : public ::testing::Test {
   std::vector<Tensor> EvaluateNodes(const GraphDef& graph,
                                     const std::vector<string>& node_names);
 
+  std::vector<Tensor> EvaluateFetchNodes(const GrapplerItem& item);
+
   void AddNode(const string& name, const string& op,
                const std::vector<string>& inputs, GraphDef* graph);