namespace tensorflow {
namespace grappler {
+namespace {
+
+// Gathers the functions from a GrapplerItem's function library that may be
+// inlined, and offers name-based lookup over them. The context keeps
+// non-owning pointers into `item.graph.library()`.
+class FunctionInliningContext {
+ public:
+  explicit FunctionInliningContext(const GrapplerItem& item)
+      : library_(&item.graph.library()), functions_(InliningCandidates(item)) {}
+
+  // The function library of the item this context was built from.
+  const FunctionDefLibrary& Library() const { return *library_; }
+
+  // True when at least one library function qualified for inlining.
+  bool HasInlinedFunctions() const { return !functions_.empty(); }
+
+  // Looks up an inlining candidate by function name; yields nullptr when the
+  // name is not a candidate.
+  const FunctionDef* FindInlinedFunction(const string& name) const {
+    const auto found = functions_.find(name);
+    return found == functions_.end() ? nullptr : found->second;
+  }
+
+ private:
+  // Builds the name -> FunctionDef map of all library functions that pass the
+  // inlining filters below.
+  std::unordered_map<string, const FunctionDef*> InliningCandidates(
+      const GrapplerItem& item) const {
+    std::unordered_map<string, const FunctionDef*> candidates;
+    for (const FunctionDef& fdef : item.graph.library().function()) {
+      const auto& attrs = fdef.attr();
+      // Functions marked as noinline are left alone.
+      const bool marked_noinline = attrs.count("_noinline") != 0;
+      // Anything marked XLA is not touched, to prevent XLA failures further
+      // down the road.
+      const bool xla_compiled =
+          attrs.count("_XlaCompile") > 0 && attrs.at("_XlaCompile").b();
+      // IdentityN nodes can't be created with no input or output, so such
+      // functions are skipped for now.
+      const bool missing_io = fdef.signature().input_arg_size() == 0 ||
+                              fdef.signature().output_arg_size() == 0;
+      if (marked_noinline || xla_compiled || missing_io) {
+        continue;
+      }
+      candidates[fdef.signature().name()] = &fdef;
+    }
+    return candidates;
+  }
+
+  const FunctionDefLibrary* library_;
+  std::unordered_map<string, const FunctionDef*> functions_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(FunctionInliningContext);
+};
+
+// Appends the data type of `arg` to `type_list`. An explicitly typed argument
+// contributes its own type; otherwise the type is taken from the attribute in
+// `func_attr` named by `arg.type_attr()`. Returns InvalidArgument when that
+// attribute is absent or holds DT_INVALID; `arg_kind` ("input"/"output") is
+// only used to word the error message.
+Status CopyArgType(const NodeDef& func_node,
+                   const std::unordered_map<string, AttrValue>& func_attr,
+                   const string& arg_kind, const OpDef::ArgDef& arg,
+                   AttrValue::ListValue* type_list) {
+  // Fast path: the signature spells out the type directly.
+  if (arg.type() != DT_INVALID) {
+    type_list->add_type(arg.type());
+    return Status::OK();
+  }
+
+  // Otherwise the type has to come from the instantiation attributes.
+  const auto attr_it = func_attr.find(arg.type_attr());
+  const bool resolved =
+      attr_it != func_attr.end() && attr_it->second.type() != DT_INVALID;
+  if (!resolved) {
+    return errors::InvalidArgument(
+        "Invalid ", arg_kind, " argument ", arg.name(), " for function ",
+        func_node.op(), " instantiated by ", func_node.name());
+  }
+  type_list->add_type(attr_it->second.type());
+  return Status::OK();
+}
+
+// Configures `inputs` as the IdentityN node that all inputs of the inlined
+// function are routed through, so that they are all evaluated before the
+// evaluation of the function body starts. The node is named
+// "<func_node name>/inlined_inputs", shares the device of `func_node`, and
+// takes over its input list. Fails if an input argument type can't be
+// resolved.
+Status HookInlinedFunctionInputs(
+    const NodeDef& func_node, const FunctionDef& func,
+    const std::unordered_map<string, AttrValue>& func_attr, NodeDef* inputs) {
+  inputs->set_op("IdentityN");
+  inputs->set_name(strings::StrCat(func_node.name(), "/", "inlined_inputs"));
+  inputs->set_device(func_node.device());
+  *inputs->mutable_input() = func_node.input();
+
+  // The "T" attribute lists one type per function input argument.
+  AttrValue::ListValue* input_types =
+      (*inputs->mutable_attr())["T"].mutable_list();
+  const auto& signature = func.signature();
+  for (int idx = 0; idx < signature.input_arg_size(); ++idx) {
+    TF_RETURN_IF_ERROR(CopyArgType(func_node, func_attr, "input",
+                                   signature.input_arg(idx), input_types));
+  }
+  return Status::OK();
+}
+
+// Configures `outputs` as the IdentityN node that all outputs of the inlined
+// function are routed through, so that the function body is fully evaluated
+// before its fanout gets scheduled. The node reuses the name and device of
+// `func_node`. `fetch` is indexed once per output argument without bounds
+// checking, so it must hold at least that many entries. Fails if an output
+// argument type can't be resolved.
+Status HookInlinedFunctionOutputs(
+    const NodeDef& func_node, const FunctionDef& func,
+    const std::unordered_map<string, AttrValue>& func_attr,
+    const gtl::ArraySlice<string> fetch, NodeDef* outputs) {
+  outputs->set_op("IdentityN");
+  outputs->set_name(func_node.name());
+  outputs->set_device(func_node.device());
+
+  AttrValue::ListValue* output_types =
+      (*outputs->mutable_attr())["T"].mutable_list();
+  const auto& signature = func.signature();
+  const int num_outputs = signature.output_arg_size();
+  for (int idx = 0; idx < num_outputs; ++idx) {
+    TF_RETURN_IF_ERROR(CopyArgType(func_node, func_attr, "output",
+                                   signature.output_arg(idx), output_types));
+    // The fetch names are used because they take the output mapping into
+    // account.
+    outputs->add_input(strings::StrCat(func_node.name(), "/", fetch[idx]));
+  }
+  return Status::OK();
+}
+
+Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
+ const FunctionInliningContext& ctx,
+ GraphDef* optimized_graph) {
+ const std::unordered_map<string, AttrValue> func_attr(
+ func_node.attr().begin(), func_node.attr().end());
-Status InlineFunction(const NodeDef& node, const FunctionDef& func,
- const FunctionDefLibrary& library, GraphDef* graph) {
- const std::unordered_map<string, AttrValue> attr(node.attr().begin(),
- node.attr().end());
std::unique_ptr<GrapplerItem> item =
- GrapplerItemFromFunctionDef(func, attr, library);
+ GrapplerItemFromFunctionDef(func, func_attr, ctx.Library());
if (!item) {
- return errors::InvalidArgument("Failed to inline function ", node.op(),
- " instantiated by ", node.name());
+ return errors::InvalidArgument("Failed to inline function ", func_node.op(),
+ " instantiated by ", func_node.name());
}
std::unordered_map<string, int> input_nodes;
input_nodes[arg.name()] = i;
}
- // Add an IdentityN op to hook the function inputs to: this ensures that
- // they're all evaluated before the evaluation of the function body starts.
- NodeDef* func_inputs = graph->add_node();
- func_inputs->set_name(strings::StrCat(node.name(), "/", "inlined_inputs"));
- func_inputs->set_op("IdentityN");
- func_inputs->set_device(node.device());
- *func_inputs->mutable_input() = node.input();
- AttrValue::ListValue* type_list =
- (*func_inputs->mutable_attr())["T"].mutable_list();
- for (const OpDef::ArgDef& arg : func.signature().input_arg()) {
- if (arg.type() != DT_INVALID) {
- type_list->add_type(arg.type());
- } else {
- auto it = attr.find(arg.type_attr());
- if (it == attr.end()) {
- return errors::InvalidArgument("Invalid input argument ", arg.name(),
- " for function ", node.op(),
- " instantiated by ", node.name());
- }
- type_list->add_type(it->second.type());
- }
- }
+ // Hook inlined function inputs to IdentityN node
+ NodeDef* func_inputs = optimized_graph->add_node();
+ TF_RETURN_IF_ERROR(
+ HookInlinedFunctionInputs(func_node, func, func_attr, func_inputs));
for (NodeDef& func_body_node : *item->graph.mutable_node()) {
if (input_nodes.find(func_body_node.name()) != input_nodes.end()) {
+ CHECK_EQ(0, func_body_node.input_size());
// Turn input placeholders into identity nodes
if (IsPlaceholder(func_body_node)) {
func_body_node.set_op("Identity");
}
- CHECK_EQ(0, func_body_node.input_size());
int input_id = input_nodes[func_body_node.name()];
func_body_node.add_input(
strings::StrCat(func_inputs->name(), ":", input_id));
} else {
// Update the input names if any.
for (string& input : *func_body_node.mutable_input()) {
- input = AddPrefixToNodeName(input, node.name());
+ input = AddPrefixToNodeName(input, /*prefix=*/func_node.name());
}
// If the node has no input, make hook it up to the func_inputs node to
// ensure it runs in the same frame as the other nodes of the function
// Add the node name as a prefix to avoid collisions after inlining
func_body_node.set_name(
- strings::StrCat(node.name(), "/", func_body_node.name()));
+ strings::StrCat(func_node.name(), "/", func_body_node.name()));
// Make sure the node is placed
- func_body_node.set_device(node.device());
-
- // Move the node to the main graph
- graph->add_node()->Swap(&func_body_node);
- }
-
- // Add an IdentityN op to hook the function outputs to: this ensures that the
- // function body is fully evaluated before its fanout gets scheduled.
- NodeDef* func_outputs = graph->add_node();
- func_outputs->set_name(node.name());
- func_outputs->set_op("IdentityN");
- func_outputs->set_device(node.device());
- type_list = (*func_outputs->mutable_attr())["T"].mutable_list();
- for (int i = 0; i < func.signature().output_arg_size(); ++i) {
- const OpDef::ArgDef& arg = func.signature().output_arg(i);
- if (arg.type() != DT_INVALID) {
- type_list->add_type(arg.type());
+ func_body_node.set_device(func_node.device());
+
+ // Check if a body node is itself a function
+ const FunctionDef* func_body_node_func =
+ ctx.FindInlinedFunction(func_body_node.op());
+ if (func_body_node_func != nullptr) {
+ // Recursively inline function calls
+ TF_RETURN_IF_ERROR(InlineFunction(func_body_node, *func_body_node_func,
+ ctx, optimized_graph));
} else {
- auto it = attr.find(arg.type_attr());
- if (it == attr.end()) {
- return errors::InvalidArgument("Invalid output argument ", arg.name(),
- " for function ", node.op(),
- " instantiated by ", node.name());
- }
- type_list->add_type(it->second.type());
+ // Move the node to the main graph
+ optimized_graph->add_node()->Swap(&func_body_node);
}
- // Use the fetch names since they take into account the output mapping.
- func_outputs->add_input(strings::StrCat(node.name(), "/", item->fetch[i]));
}
+ // Hook inlined function outputs to IdentityN node
+ NodeDef* func_outputs = optimized_graph->add_node();
+ TF_RETURN_IF_ERROR(HookInlinedFunctionOutputs(func_node, func, func_attr,
+ item->fetch, func_outputs));
+
return Status::OK();
}
return Status::OK();
}
+} // namespace
+
Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
- std::unordered_map<string, const FunctionDef*> functions;
- for (const FunctionDef& func : item.graph.library().function()) {
- // Don't inline functions marked as noinline
- if (func.attr().count("_noinline") != 0) {
- continue;
- }
- // Don't touch anything marked XLA to prevent XLA failures further down the
- // road.
- if (func.attr().count("_XlaCompile") > 0 &&
- func.attr().at("_XlaCompile").b()) {
- continue;
- }
- // Can't create IdentityN nodes with no input or output: skip these
- // functions for now.
- if (func.signature().input_arg_size() == 0 ||
- func.signature().output_arg_size() == 0) {
- continue;
- }
- functions[func.signature().name()] = &func;
- }
+ FunctionInliningContext function_inlining_ctx(item);
- // Nothing to do.
- if (functions.empty()) {
+ // Nothing to do here.
+ if (!function_inlining_ctx.HasInlinedFunctions()) {
*optimized_graph = item.graph;
return Status::OK();
}
TF_RETURN_IF_ERROR(InlineSymbolicGradient(node, &env, optimized_graph));
continue;
}
- auto it = functions.find(node.op());
- if (it == functions.end()) {
- *optimized_graph->add_node() = node;
+
+ const FunctionDef* func =
+ function_inlining_ctx.FindInlinedFunction(node.op());
+ if (func != nullptr) {
+ TF_RETURN_IF_ERROR(
+ InlineFunction(node, *func, function_inlining_ctx, optimized_graph));
} else {
- TF_RETURN_IF_ERROR(InlineFunction(node, *it->second, item.graph.library(),
- optimized_graph));
+ *optimized_graph->add_node() = node;
}
}
namespace grappler {
namespace {
-class FunctionOptimizerTest : public GrapplerTest {};
+// Device string assigned to test graph nodes and compared against the device
+// of the nodes produced by the optimizer.
+constexpr char kDevice[] = "/device:CPU:0";
+
+// Test fixture providing helpers to build scalar feed tensors.
+class FunctionOptimizerTest : public GrapplerTest {
+ protected:
+  // Builds a scalar DT_FLOAT tensor holding `value`.
+  Tensor MakeScalarTensor(float value) {
+    return MakeScalar<float>(DT_FLOAT, value);
+  }
+
+  // Builds a scalar DT_INT32 tensor holding `value`.
+  Tensor MakeScalarTensor(int value) {
+    return MakeScalar<int>(DT_INT32, value);
+  }
+
+ private:
+  // Shared implementation: allocates a scalar tensor of `dtype` and stores
+  // `value` in it.
+  template <typename T>
+  Tensor MakeScalar(DataType dtype, T value) {
+    Tensor result(dtype, {});
+    result.scalar<T>()() = value;
+    return result;
+  }
+};
TEST_F(FunctionOptimizerTest, SimpleFunction) {
// Build a graph to compute y = XTimesTwo(x)
}
EXPECT_EQ(7, count);
+ Tensor pi = MakeScalarTensor(3.14f);
item.fetch = {"z"};
- Tensor pi(DT_FLOAT, {});
- pi.flat<float>()(0) = 3.14f;
item.feed.emplace_back("x", pi);
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
}
EXPECT_EQ(6, count);
+ Tensor pi = MakeScalarTensor(3.14f);
item.fetch = {"z"};
- Tensor pi(DT_FLOAT, {});
- pi.flat<float>()(0) = 3.14f;
item.feed.emplace_back("x", pi);
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
}
EXPECT_EQ(6, count);
+ Tensor pi = MakeScalarTensor(3.14f);
item.fetch = {"z"};
- Tensor pi(DT_FLOAT, {});
- pi.flat<float>()(0) = 3.14f;
item.feed.emplace_back("x", pi);
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
TF_EXPECT_OK(status);
item.fetch = {"z0", "z1", "z2"};
- Tensor in(DT_FLOAT, {});
- in.flat<float>()(0) = 3.14f;
- item.feed.emplace_back("x0", in);
- in.flat<float>()(0) = 2.7f;
- item.feed.emplace_back("x1", in);
- in.flat<float>()(0) = 1.0f;
- item.feed.emplace_back("x2", in);
- in.flat<float>()(0) = -1.0f;
- item.feed.emplace_back("x4", in);
- Tensor in_int(DT_INT32, {});
- in_int.flat<int>()(0) = 1234;
- item.feed.emplace_back("x3", in_int);
+ item.feed.emplace_back("x0", MakeScalarTensor(3.14f));
+ item.feed.emplace_back("x1", MakeScalarTensor(2.7f));
+ item.feed.emplace_back("x2", MakeScalarTensor(1.0f));
+ item.feed.emplace_back("x4", MakeScalarTensor(-1.0f));
+ item.feed.emplace_back("x3", MakeScalarTensor(1234));
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
EXPECT_EQ(item.graph.DebugString(), output.DebugString());
}
+// Verifies that a function whose body calls another library function is
+// inlined recursively: MySquare(x) is inlined, and so is the MyMul call inside
+// its body. Checks both the structure of the optimized graph and that it
+// evaluates to the same value as the original graph.
+TEST_F(FunctionOptimizerTest, InlineFunctionWithNestedFunctionCall) {
+  // Define square via function library:
+  //   MySquare(x) = MyMul(x, x)
+
+  FunctionDef mul_func = FunctionDefHelper::Create(
+      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
+      {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "output:z:0"}});
+
+  FunctionDef square_func = FunctionDefHelper::Create(
+      "MySquare", {"x:T"}, {"z:T"}, {"T: {float, double}"},
+      {{{"output"}, "MyMul", {"x", "x"}, {{"T", "$T"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "output:z:0"}});
+
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {test::function::NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}},
+                            kDevice),
+       test::function::NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}},
+                            kDevice),
+       test::function::NDef("outputs", "Identity", {"square:0"},
+                            {{"T", DT_FLOAT}}, kDevice)},
+      // FunctionLib
+      {mul_func, square_func});
+
+  GraphDef output;
+  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  // Count the expected nodes while checking each of them. The counter is
+  // pre-incremented on purpose: `count++` evaluates to 0 for the very first
+  // matched node, which would make the condition false and silently skip that
+  // node's expectations.
+  int count = 0;
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "square/inlined_inputs" && ++count) {
+      EXPECT_EQ("IdentityN", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("a", node.input(0));
+    } else if (node.name() == "square/x" && ++count) {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("square/inlined_inputs:0", node.input(0));
+    } else if (node.name() == "square/output/inlined_inputs" && ++count) {
+      EXPECT_EQ("IdentityN", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("square/x", node.input(0));
+      EXPECT_EQ("square/x", node.input(1));
+    } else if (node.name() == "square/output/x" && ++count) {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("square/output/inlined_inputs:0", node.input(0));
+    } else if (node.name() == "square/output/y" && ++count) {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("square/output/inlined_inputs:1", node.input(0));
+    } else if (node.name() == "square/output/output" && ++count) {
+      EXPECT_EQ("Mul", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("square/output/x", node.input(0));
+      EXPECT_EQ("square/output/y", node.input(1));
+    } else if (node.name() == "square/output" && ++count) {
+      EXPECT_EQ("IdentityN", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("square/output/output:0", node.input(0));
+    } else if (node.name() == "square" && ++count) {
+      EXPECT_EQ("IdentityN", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("square/output:0", node.input(0));
+    } else if (node.name() == "outputs" && ++count) {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(kDevice, node.device());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("square:0", node.input(0));
+    }
+  }
+  EXPECT_EQ(9, count);
+
+  // The optimized graph must compute the same result as the original one.
+  item.fetch = {"outputs"};
+  item.feed.emplace_back("a", MakeScalarTensor(2.0f));
+  auto tensors_expected = EvaluateFetchNodes(item);
+
+  GrapplerItem optimized(item, std::move(output));
+  auto tensors = EvaluateFetchNodes(optimized);
+
+  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+}
+
TEST_F(FunctionOptimizerTest, SymbolicGradients) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();