[tf.data] Move the C++ Dataset class implementations to the framework library.
authorDerek Murray <mrry@google.com>
Wed, 7 Feb 2018 21:44:30 +0000 (13:44 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 7 Feb 2018 21:48:12 +0000 (13:48 -0800)
This enables the use of the `DatasetOpKernel` subclasses in custom op library
code. A subsequent change will move `tf.contrib.data` kernel implementations
to a custom op library.

Implementation note: This change moves some classes from
"tensorflow/core/graph/..." into the framework library, which does not
include any code in "tensorflow/core/common_runtime/...". To break the
dependency from "tensorflow/core/framework/dataset.cc" to
"tensorflow/core/common_runtime/...", the `GraphDefBuilderToGraph()`
method has been split out from the `GraphDefBuilder` class (where it
was previously exposed as the `GraphDefBuilder::ToGraph()` utility
method) and added to a new
"tensorflow/core/graph/graph_def_builder_util.h" module.  This method
depends on ".../graph/graph_constructor.cc", which depends directly on
".../common_runtime/shape_refiner.h" and indirectly on
".../common_runtime/graph_runner.h". Since this method was used only
in tests, these have been updated to point to the new utility method.

PiperOrigin-RevId: 184888903

16 files changed:
tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
tensorflow/contrib/cmake/tf_core_cpu.cmake
tensorflow/contrib/cmake/tf_core_framework.cmake
tensorflow/core/BUILD
tensorflow/core/common_runtime/placer_test.cc
tensorflow/core/framework/dataset.cc [moved from tensorflow/core/kernels/data/dataset.cc with 97% similarity]
tensorflow/core/framework/dataset.h
tensorflow/core/graph/algorithm_test.cc
tensorflow/core/graph/graph_def_builder.cc
tensorflow/core/graph/graph_def_builder.h
tensorflow/core/graph/graph_def_builder_test.cc
tensorflow/core/graph/graph_def_builder_util.cc [new file with mode: 0644]
tensorflow/core/graph/graph_def_builder_util.h [new file with mode: 0644]
tensorflow/core/graph/subgraph_test.cc
tensorflow/core/kernels/data/BUILD
tensorflow/core/kernels/data/iterator_ops.cc

index 454f0ae..1a8858c 100644 (file)
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -80,7 +81,7 @@ TEST(XlaCompilationTest, Chains) {
         ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
     Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
     ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
-    TF_EXPECT_OK(builder.ToGraph(graph.get()));
+    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
   }
 
   TF_ASSERT_OK(MarkForCompilation(&graph));
@@ -105,7 +106,7 @@ TEST(XlaCompilationTest, UncompilableCycles) {
     Node* b =
         ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
     ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
-    TF_EXPECT_OK(builder.ToGraph(graph.get()));
+    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
   }
 
   TF_ASSERT_OK(MarkForCompilation(&graph));
@@ -125,7 +126,7 @@ TEST(XlaCompilationTest, CompilableCycles) {
                                          .WithAttr("value", Tensor()));
     Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
     ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
-    TF_EXPECT_OK(builder.ToGraph(graph.get()));
+    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
   }
 
   TF_ASSERT_OK(MarkForCompilation(&graph));
@@ -148,7 +149,7 @@ TEST(XlaCompilationTest, UnsupportedTypes) {
                      .WithAttr("value", Tensor(DT_COMPLEX128, TensorShape())));
     Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
     ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
-    TF_EXPECT_OK(builder.ToGraph(graph.get()));
+    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
   }
 
   TF_ASSERT_OK(MarkForCompilation(&graph));
@@ -177,7 +178,7 @@ TEST(XlaCompilationTest, ConcatWithConstArg) {
     concat_builder.Input(dim).Input({a, a}).Attr("N", 2);
     builder.opts().FinalizeBuilder(&concat_builder);
 
-    TF_EXPECT_OK(builder.ToGraph(graph.get()));
+    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
   }
 
   TF_ASSERT_OK(MarkForCompilation(&graph));
@@ -212,7 +213,7 @@ TEST(XlaCompilationTest, FunctionCalls) {
     Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
     ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
     ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E"));
-    TF_EXPECT_OK(builder.ToGraph(graph.get()));
+    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
   }
 
   TF_ASSERT_OK(MarkForCompilation(&graph, &flib_def));
@@ -244,7 +245,7 @@ TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
     Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C"));
     Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D"));
     ops::UnaryOp("Shape", d, builder.opts().WithName("E"));
-    TF_EXPECT_OK(builder.ToGraph(graph.get()));
+    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
   }
   TF_ASSERT_OK(MarkForCompilation(&graph));
   auto clusters = GetClusters(*graph);
@@ -330,7 +331,7 @@ TEST(XlaCompilationTest, SymbolicGradients) {
     d_builder.Input({c, c});
     builder.opts().FinalizeBuilder(&d_builder);
 
-    TF_EXPECT_OK(builder.ToGraph(graph.get()));
+    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
   }
 
   TF_ASSERT_OK(MarkForCompilation(&graph));
@@ -382,7 +383,7 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
     ops::BinaryOp(
         "MatMul", a, b,
         builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
-    TF_CHECK_OK(builder.ToGraph(graph.get()));
+    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
   }
 
   TF_ASSERT_OK(MarkForCompilation(&graph));
@@ -413,7 +414,7 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
     ops::BinaryOp(
         "Add", b, c,
         builder.opts().WithName("D").WithAttr(kXlaScopeAttr, "Scope2"));
-    TF_CHECK_OK(builder.ToGraph(graph.get()));
+    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
   }
 
   TF_ASSERT_OK(MarkForCompilation(&graph));
@@ -443,7 +444,7 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
         "Relu", a,
         builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
     ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
-    TF_CHECK_OK(builder.ToGraph(graph.get()));
+    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
   }
 
   TF_ASSERT_OK(MarkForCompilation(&graph));
@@ -484,7 +485,7 @@ TEST(XlaCompilationTest, Resources) {
     Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C"));
     Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D"));
     ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
-    TF_EXPECT_OK(builder.ToGraph(graph.get()));
+    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
   }
   TF_ASSERT_OK(MarkForCompilation(&graph));
   auto clusters = GetClusters(*graph);
@@ -541,7 +542,7 @@ TEST(XlaCompilationTest, Retval) {
                      .WithAttr("T", DT_FLOAT)
                      .WithAttr("index", 0));
 
-    TF_EXPECT_OK(builder.ToGraph(graph.get()));
+    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
   }
 
   TF_ASSERT_OK(MarkForCompilation(&graph));
index e4213ea..96ac60d 100644 (file)
@@ -50,6 +50,12 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs
     "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc"
     "${tensorflow_source_dir}/tensorflow/core/graph/graph.h"
     "${tensorflow_source_dir}/tensorflow/core/graph/graph.cc"
+    "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.h"
+    "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.cc"
+    "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.h"
+    "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.cc"
+    "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.h"
+    "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.cc"
     "${tensorflow_source_dir}/tensorflow/core/graph/while_context.h"
     "${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc"
     "${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.h"
index 129c208..a1c3203 100644 (file)
@@ -292,6 +292,12 @@ file(GLOB_RECURSE tf_core_framework_srcs
     "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc"
     "${tensorflow_source_dir}/tensorflow/core/graph/graph.h"
     "${tensorflow_source_dir}/tensorflow/core/graph/graph.cc"
+    "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.h"
+    "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.cc"
+    "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.h"
+    "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.cc"
+    "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.h"
+    "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.cc"
     "${tensorflow_source_dir}/tensorflow/core/graph/while_context.h"
     "${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc"
     "${tensorflow_source_dir}/tensorflow/core/util/*.h"
index 7fade69..c25aac3 100644 (file)
@@ -784,6 +784,7 @@ tf_cuda_library(
         "graph/graph.h",
         "graph/graph_constructor.h",
         "graph/graph_def_builder.h",
+        "graph/graph_def_builder_util.h",
         "graph/node_builder.h",
         "graph/validate.h",
         "graph/while_context.h",
@@ -1718,6 +1719,9 @@ FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [
     "platform/variant_coding.h",
     "graph/edgeset.h",
     "graph/graph.h",
+    "graph/graph_def_builder.h",
+    "graph/node_builder.h",
+    "graph/tensor_id.h",
 ] + glob(
     [
         "example/**/*.h",
@@ -1804,6 +1808,9 @@ tf_cuda_library(
         ] + [
             "graph/edgeset.cc",
             "graph/graph.cc",
+            "graph/graph_def_builder.cc",
+            "graph/node_builder.cc",
+            "graph/tensor_id.cc",
             "graph/while_context.h",
             "graph/while_context.cc",
         ],
@@ -1932,6 +1939,7 @@ GRAPH_HDRS = [
     "graph/graph.h",
     "graph/graph_constructor.h",  # NOTE(mrry): Don't include the .cc since it depends on common_runtime.
     "graph/graph_def_builder.h",
+    "graph/graph_def_builder_util.h",
     "graph/graph_partition.h",
     "graph/mkl_layout_pass.h",
     "graph/mkl_tfconversion_pass.h",
@@ -1952,12 +1960,9 @@ tf_cuda_library(
         "graph/colors.cc",
         "graph/control_flow.cc",
         "graph/costmodel.cc",
-        "graph/graph_def_builder.cc",
         "graph/graph_partition.cc",
-        "graph/node_builder.cc",
         "graph/optimizer_cse.cc",
         "graph/subgraph.cc",
-        "graph/tensor_id.cc",
         "graph/validate.cc",
     ],
     hdrs = GRAPH_HDRS,
@@ -1986,6 +1991,7 @@ tf_cuda_library(
         "common_runtime/shape_refiner.h",
         "framework/versions.h",
         "graph/graph_constructor.cc",  # Depends on common_runtime.
+        "graph/graph_def_builder_util.cc",  # Depends on common_runtime.
         "public/session.h",
         "public/session_options.h",
         "public/version.h",
index 02c9cd5..098024d 100644 (file)
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
 #include "tensorflow/core/lib/core/error_codes.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -193,7 +194,7 @@ class PlacerTest : public ::testing::Test {
   // Builds the given graph, and (if successful) indexes the node
   // names for use in placement, and later lookup.
   Status BuildGraph(const GraphDefBuilder& builder, Graph* out_graph) {
-    TF_RETURN_IF_ERROR(builder.ToGraph(out_graph));
+    TF_RETURN_IF_ERROR(GraphDefBuilderToGraph(builder, out_graph));
     nodes_by_name_.clear();
     for (Node* node : out_graph->nodes()) {
       nodes_by_name_[node->name()] = node->id();
similarity index 97%
rename from tensorflow/core/kernels/data/dataset.cc
rename to tensorflow/core/framework/dataset.cc
index d18cb16..4145ef7 100644 (file)
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/kernels/data/dataset.h"
-#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/dataset.h"
+
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/graph/node_builder.h"
 
@@ -265,10 +265,6 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
   MakeDataset(ctx, input, another_input, output);
 }
 
-Allocator* IteratorContext::allocator(AllocatorAttributes attrs) {
-  return params_.lib->device()->GetAllocator(attrs);
-}
-
 const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
 const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] =
     "_DATASET_GRAPH_OUTPUT_NODE";
index 96566c2..6ab23d9 100644 (file)
@@ -274,7 +274,7 @@ class IteratorContext {
     std::shared_ptr<const FunctionLibraryDefinition> function_library = nullptr;
 
     // The Allocator to be used to allocate the output of an iterator.
-    Allocator* allocator = nullptr;
+    std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;
   };
 
   explicit IteratorContext(Params params) : params_(std::move(params)) {}
@@ -301,7 +301,9 @@ class IteratorContext {
 
   void set_lib(FunctionLibraryRuntime* lib) { params_.lib = lib; }
 
-  Allocator* allocator(AllocatorAttributes attrs);
+  Allocator* allocator(AllocatorAttributes attrs) {
+    return params_.allocator_getter(attrs);
+  }
 
  private:
   Params params_;
index 0cdcdb6..99ced0c 100644 (file)
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
 #include "tensorflow/core/graph/subgraph.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -81,7 +82,7 @@ TEST(AlgorithmTest, ReversePostOrder) {
   BinaryOp("TestMul", w2, {input, 1}, b.opts().WithName("t3"));
 
   Graph g(OpRegistry::Global());
-  TF_ASSERT_OK(b.ToGraph(&g));
+  TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));
   std::vector<Node*> order;
 
   // Test reverse post order:
@@ -139,7 +140,7 @@ TEST(AlgorithmTest, ReversePostOrderStable) {
     BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t3"));
 
     Graph g(OpRegistry::Global());
-    TF_ASSERT_OK(b.ToGraph(&g));
+    TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));
     std::vector<Node*> order;
 
     // Test reverse post order generates expected ordering.
index 33d2021..7a58347 100644 (file)
@@ -17,7 +17,6 @@ limitations under the License.
 
 #include <utility>
 
-#include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/tensor_id.h"
 #include "tensorflow/core/lib/core/errors.h"
 
@@ -72,16 +71,6 @@ Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const {
   return status_;
 }
 
-Status GraphDefBuilder::ToGraph(Graph* graph) const {
-  if (status_.ok()) {
-    GraphDef graph_def;
-    graph_.ToGraphDef(&graph_def);
-    GraphConstructorOptions opts;
-    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, graph));
-  }
-  return status_;
-}
-
 string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const {
   if (name_.empty()) return graph_->NewName(op);
   return name_;
index a2c0c4d..776a74c 100644 (file)
@@ -161,14 +161,6 @@ class GraphDefBuilder {
   // successful, and if so fill *graph_def.
   Status ToGraphDef(GraphDef* graph_def) const;
 
-  // Like ToGraphDef(), but converts to a Graph (using the default
-  // GraphConstructorOptions).
-  // TODO(josh11b): Make this faster; right now it converts
-  // Graph->GraphDef->Graph.  This cleans up the graph (e.g. adds
-  // edges from the source and to the sink node, resolves back edges
-  // by name), and makes sure the resulting graph is valid.
-  Status ToGraph(Graph* graph) const;
-
   // Adds the function and gradient definitions in `fdef_lib` to this graph's op
   // registry. Ignores duplicate functions, and returns a bad status if an
   // imported function differs from an existing function or op with the same
index e928c81..be3c2be 100644 (file)
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/core/framework/versions.pb.h"
 #include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
@@ -34,7 +35,7 @@ TEST(GraphDefBuilderTest, Version) {
 
   // Check version when we convert to a Graph
   Graph graph(OpRegistry::Global());
-  TF_EXPECT_OK(builder.ToGraph(&graph));
+  TF_EXPECT_OK(GraphDefBuilderToGraph(builder, &graph));
   ASSERT_EQ(graph.versions().producer(), TF_GRAPH_DEF_VERSION);
   ASSERT_EQ(graph.versions().min_consumer(), TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
 
diff --git a/tensorflow/core/graph/graph_def_builder_util.cc b/tensorflow/core/graph/graph_def_builder_util.cc
new file mode 100644 (file)
index 0000000..102c721
--- /dev/null
@@ -0,0 +1,28 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/graph/graph_def_builder_util.h"
+
+#include "tensorflow/core/graph/graph_constructor.h"
+
+namespace tensorflow {
+
+Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, Graph* graph) {
+  GraphDef graph_def;
+  TF_RETURN_IF_ERROR(builder.ToGraphDef(&graph_def));
+  GraphConstructorOptions opts;
+  return ConvertGraphDefToGraph(opts, graph_def, graph);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_def_builder_util.h b/tensorflow/core/graph/graph_def_builder_util.h
new file mode 100644 (file)
index 0000000..4a157e5
--- /dev/null
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_UTIL_H_
+#define TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_UTIL_H_
+
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class Graph;
+
+// Converts the `GraphDef` being built by `builder` to a `Graph` and
+// stores it in `*graph`.
+// TODO(josh11b): Make this faster; right now it converts
+// Graph->GraphDef->Graph.  This cleans up the graph (e.g. adds
+// edges from the source and to the sink node, resolves back edges
+// by name), and makes sure the resulting graph is valid.
+Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, Graph* graph);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_UTIL_H_
index fde1ea1..7219d98 100644 (file)
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -361,7 +362,7 @@ static void BM_SubgraphHelper(int iters, int num_nodes,
         last_node = ops::SourceOp("In", b.opts().WithName(name));
       }
     }
-    TF_CHECK_OK(b.ToGraph(&g));
+    TF_CHECK_OK(GraphDefBuilderToGraph(b, &g));
   }
 
   std::vector<string> fed;
index c4e2125..8e91baa 100644 (file)
@@ -44,9 +44,10 @@ tf_kernel_library(
     ],
 )
 
+# TODO(mrry): Remove this empty forwarding library.
 cc_library(
     name = "dataset",
-    srcs = ["dataset.cc"],
+    srcs = [],
     hdrs = ["dataset.h"],
     deps = [
         "//tensorflow/core:core_cpu",
index fc3e291..d7d4ad5 100644 (file)
@@ -160,6 +160,10 @@ class IteratorResource : public ResourceBase {
       params.runner = *(ctx->runner());
       params.function_library = flib_def;
       params.lib = lib_;
+      DeviceBase* device = lib_->device();
+      params.allocator_getter = [device](AllocatorAttributes attrs) {
+        return device->GetAllocator(attrs);
+      };
       IteratorContext iter_ctx(std::move(params));
 
       TF_RETURN_IF_ERROR(captured_iterator->Restore(&iter_ctx, reader));
@@ -605,6 +609,11 @@ class ToSingleElementOp : public AsyncOpKernel {
       params.env = ctx->env();
       params.runner = *(ctx->runner());
       params.lib = ctx->function_library();
+      DeviceBase* device = ctx->function_library()->device();
+      params.allocator_getter = [device](AllocatorAttributes attrs) {
+        return device->GetAllocator(attrs);
+      };
+
       IteratorContext iter_ctx(std::move(params));
 
       std::vector<Tensor> components;
@@ -863,6 +872,10 @@ class IteratorGetNextOp : public AsyncOpKernel {
           };
           params.runner = *(ctx->runner());
           params.function_library = iterator->function_library();
+          DeviceBase* device = ctx->function_library()->device();
+          params.allocator_getter = [device](AllocatorAttributes attrs) {
+            return device->GetAllocator(attrs);
+          };
           IteratorContext iter_ctx(std::move(params));
 
           OP_REQUIRES_OK_ASYNC(
@@ -905,6 +918,10 @@ class IteratorGetNextSyncOp : public OpKernel {
     };
     params.runner = *(ctx->runner());
     params.function_library = iterator->function_library();
+    DeviceBase* device = ctx->function_library()->device();
+    params.allocator_getter = [device](AllocatorAttributes attrs) {
+      return device->GetAllocator(attrs);
+    };
     IteratorContext iter_ctx(std::move(params));
 
     OP_REQUIRES_OK(ctx,