[TF] Mark DT_STRING and DT_RESOURCE types as always sitting on host memory.
author     Eugene Brevdo <ebrevdo@google.com>
           Wed, 13 Dec 2017 01:01:02 +0000 (17:01 -0800)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Wed, 13 Dec 2017 01:05:03 +0000 (17:05 -0800)
This is important when these types appear in an op's input or output type lists,
where the op signature may not be able to declare the individual list elements as sitting on host memory.

For DT_RESOURCE types, only the handles are marked as sitting on host memory;
the actual data they refer to may still reside on the GPU. A sketch of the
list-argument case this addresses follows the commit metadata below.

PiperOrigin-RevId: 178837213
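
A minimal, illustrative sketch of the situation described above (not part of this
change): the op "ExampleListOp", ExampleKernel, and ShowMemoryTypes() are
hypothetical, and the sketch assumes the same TensorFlow C++ framework headers
used by memory_types_test.cc. Because the resource/string entries sit inside a
single list argument, the GPU kernel registration has no per-argument name to
pass to HostMemory(); with this change, MemoryTypesForNode() still reports
HOST_MEMORY for those entries.

// Hypothetical op whose single list argument can mix tensor types; individual
// DT_RESOURCE / DT_STRING elements cannot be singled out via HostMemory(...).
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

class ExampleKernel : public OpKernel {
 public:
  explicit ExampleKernel(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext*) override {}
};

REGISTER_OP("ExampleListOp")
    .Input("in: Tlist")
    .Output("out: Tlist")
    .Attr("Tlist: list(type)");
// Note: no .HostMemory("...") is (or can be) declared for the list elements.
REGISTER_KERNEL_BUILDER(Name("ExampleListOp").Device(DEVICE_GPU), ExampleKernel);

void ShowMemoryTypes() {
  NodeDef node_def;
  TF_CHECK_OK(NodeDefBuilder("example", "ExampleListOp")
                  .Input(FakeInput({DT_FLOAT, DT_RESOURCE, DT_STRING}))
                  .Finalize(&node_def));
  MemoryTypeVector input, output;
  TF_CHECK_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def,
                                 &input, &output));
  // With this change, both `input` and `output` come back as
  // {DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY}: the resource and string
  // entries are forced to host memory even though the kernel registration
  // above could not name them.
}

}  // namespace tensorflow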

tensorflow/core/framework/memory_types.cc
tensorflow/core/framework/memory_types_test.cc
tensorflow/core/framework/types.cc
tensorflow/core/framework/types.h

diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc
index 6a2eed94b94971d20faffa1608627290c1109d66..270118bb678e110269be9aa67a3904e36c34c512 100644
--- a/tensorflow/core/framework/memory_types.cc
+++ b/tensorflow/core/framework/memory_types.cc
@@ -61,7 +61,8 @@ void MemoryTypesHelper(const NameRangeMap& name_map,
 }
 
 MemoryType MTypeFromDType(const DataType dtype) {
-  return (dtype == DT_INT32) ? HOST_MEMORY : DEVICE_MEMORY;
+  return (dtype == DT_INT32 || DataTypeAlwaysOnHost(dtype)) ? HOST_MEMORY
+                                                            : DEVICE_MEMORY;
 }
 
 }  // namespace
@@ -118,6 +119,20 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
         "HostMemory args '", str_util::Join(host_memory_args, "', '"),
         "' not found in OpDef: ", SummarizeOpDef(*op_def));
   }
+  CHECK_LE(inp_mtypes->size(), inp_dtypes.size());
+  CHECK_LE(out_mtypes->size(), out_dtypes.size());
+
+  // Mark e.g. all resource and string types as host memory.
+  for (int i = 0; i < inp_mtypes->size(); ++i) {
+    if (DataTypeAlwaysOnHost(inp_dtypes[i])) {
+      (*inp_mtypes)[i] = HOST_MEMORY;
+    }
+  }
+  for (int i = 0; i < out_mtypes->size(); ++i) {
+    if (DataTypeAlwaysOnHost(out_dtypes[i])) {
+      (*out_mtypes)[i] = HOST_MEMORY;
+    }
+  }
 
   std::vector<int32> hostmem_attr;
   if (GetNodeAttr(ndef, "_input_hostmem", &hostmem_attr).ok()) {
diff --git a/tensorflow/core/framework/memory_types_test.cc b/tensorflow/core/framework/memory_types_test.cc
index 4704da9a119c2b06db5c8b1a3874417a0b1c3617..3126ea8e5f8974cb11f88301de613eb5b920830f 100644
--- a/tensorflow/core/framework/memory_types_test.cc
+++ b/tensorflow/core/framework/memory_types_test.cc
@@ -36,11 +36,13 @@ REGISTER_OP("HostMemoryTest")
     .Input("b: T")
     .Input("c: N * string")
     .Input("d: Tlist")
+    .Input("e: Rlist")
     .Output("o: N * T")
     .Output("p: Tlist")
     .Attr("T: type")
     .Attr("N: int")
-    .Attr("Tlist: list(type)");
+    .Attr("Tlist: list(type)")
+    .Attr("Rlist: list(type)");
 REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel);
 REGISTER_KERNEL_BUILDER(Name("HostMemoryTest")
                             .Device(DEVICE_GPU)
@@ -57,15 +59,20 @@ TEST(MemoryTypesForNode, Simple) {
                    .Input(FakeInput(DT_BOOL))
                    .Input(FakeInput(3))
                    .Input(FakeInput({DT_INT32, DT_FLOAT, DT_INT32}))
+                   .Input(FakeInput({DT_RESOURCE, DT_STRING, DT_RESOURCE}))
                    .Finalize(&node_def));
   MemoryTypeVector input, output;
 
   TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
                                   &input, &output));
-  EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
-                              DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
-                              DEVICE_MEMORY, DEVICE_MEMORY}),
-            input);
+  // a:float, b:bool, c:3*string, d:(int32, float, int32),
+  // e:(resource, string, resource)
+  EXPECT_EQ(
+      MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY,
+                        HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
+                        DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
+      input);
+  // o:3*bool, p:(int32, float, int32)
   EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
                               DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
             output);
@@ -74,7 +81,8 @@ TEST(MemoryTypesForNode, Simple) {
                                   &input, &output));
   EXPECT_EQ(
       MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY,
-                        HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
+                        HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
+                        HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
       input);
   EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
                               DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc
index b082dfbd031cde572ed255a19a767c855cc56611..58354d6f4edea1f29ba033f2579324d400a532ab 100644
--- a/tensorflow/core/framework/types.cc
+++ b/tensorflow/core/framework/types.cc
@@ -306,6 +306,18 @@ bool DataTypeCanUseMemcpy(DataType dt) {
   }
 }
 
+bool DataTypeAlwaysOnHost(DataType dt) {
+  // Includes DT_STRING and DT_RESOURCE.
+  switch (dt) {
+    case DT_STRING:
+    case DT_STRING_REF:
+    case DT_RESOURCE:
+      return true;
+    default:
+      return false;
+  }
+}
+
 bool DataTypeIsFloating(DataType dt) {
   switch (dt) {
     case DT_HALF:
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index 652985658a20b094ac582466972a62b9f1e287a2..27005c0e93267ff4f91d470a011be6d673fe8cc2 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -239,6 +239,11 @@ bool DataTypeIsUnsigned(DataType dt);
 // Returns a 0 on failure
 int DataTypeSize(DataType dt);
 
+// Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE.
+// For DT_RESOURCE, the handle always sits on host (even if the underlying
+// object has device-allocated resources).
+bool DataTypeAlwaysOnHost(DataType dt);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_FRAMEWORK_TYPES_H_
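
A minimal usage sketch of the new predicate declared above; the standalone
main() is illustrative only and assumes the TensorFlow framework headers are
available, it is not part of the patch:

#include <iostream>

#include "tensorflow/core/framework/types.h"

int main() {
  using tensorflow::DataTypeAlwaysOnHost;
  // String tensors and resource handles always sit in host memory ...
  std::cout << DataTypeAlwaysOnHost(tensorflow::DT_STRING) << "\n";      // 1
  std::cout << DataTypeAlwaysOnHost(tensorflow::DT_STRING_REF) << "\n";  // 1
  std::cout << DataTypeAlwaysOnHost(tensorflow::DT_RESOURCE) << "\n";    // 1
  // ... while ordinary numeric tensor types do not.
  std::cout << DataTypeAlwaysOnHost(tensorflow::DT_FLOAT) << "\n";       // 0
  return 0;
}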