}
MemoryType MTypeFromDType(const DataType dtype) {
- return (dtype == DT_INT32) ? HOST_MEMORY : DEVICE_MEMORY;
+ return (dtype == DT_INT32 || DataTypeAlwaysOnHost(dtype)) ? HOST_MEMORY
+ : DEVICE_MEMORY;
}
} // namespace
"HostMemory args '", str_util::Join(host_memory_args, "', '"),
"' not found in OpDef: ", SummarizeOpDef(*op_def));
}
+ CHECK_LE(inp_mtypes->size(), inp_dtypes.size());
+ CHECK_LE(out_mtypes->size(), out_dtypes.size());
+
+ // Mark e.g. all resource and string types as host memory.
+ for (int i = 0; i < inp_mtypes->size(); ++i) {
+ if (DataTypeAlwaysOnHost(inp_dtypes[i])) {
+ (*inp_mtypes)[i] = HOST_MEMORY;
+ }
+ }
+ for (int i = 0; i < out_mtypes->size(); ++i) {
+ if (DataTypeAlwaysOnHost(out_dtypes[i])) {
+ (*out_mtypes)[i] = HOST_MEMORY;
+ }
+ }
std::vector<int32> hostmem_attr;
if (GetNodeAttr(ndef, "_input_hostmem", &hostmem_attr).ok()) {
.Input("b: T")
.Input("c: N * string")
.Input("d: Tlist")
+ .Input("e: Rlist")
.Output("o: N * T")
.Output("p: Tlist")
.Attr("T: type")
.Attr("N: int")
- .Attr("Tlist: list(type)");
+ .Attr("Tlist: list(type)")
+ .Attr("Rlist: list(type)");
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest")
.Device(DEVICE_GPU)
.Input(FakeInput(DT_BOOL))
.Input(FakeInput(3))
.Input(FakeInput({DT_INT32, DT_FLOAT, DT_INT32}))
+ .Input(FakeInput({DT_RESOURCE, DT_STRING, DT_RESOURCE}))
.Finalize(&node_def));
MemoryTypeVector input, output;
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
&input, &output));
- EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
- DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
- DEVICE_MEMORY, DEVICE_MEMORY}),
- input);
+ // a:float, b:bool, c:3*string, d:(int32, float, int32),
+ // e:(resource, string, resource)
+ EXPECT_EQ(
+ MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY,
+ HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
+ DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
+ input);
+ // o:3*bool, p:(int32, float, int32)
EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
output);
&input, &output));
EXPECT_EQ(
MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY,
- HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
+ HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
+ HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
input);
EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),