From 6d793e177ce377d52772574a3eb90af88e780f97 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 27 Apr 2018 12:46:46 -0700 Subject: [PATCH] Replace GrapplerFunctionItem input with a constant. PiperOrigin-RevId: 194579253 --- tensorflow/core/grappler/utils/functions.cc | 63 +++++++++++++++++++- tensorflow/core/grappler/utils/functions.h | 9 ++- tensorflow/core/grappler/utils/functions_test.cc | 75 ++++++++++++++++++++++++ 3 files changed, 143 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index 790809b..79b823f 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -566,6 +566,60 @@ Status RegisterGrapplerFunctionConnectivity( return Status::OK(); } +Status ReplaceInputWithConst(const NodeDef& input_const, int input_position, + GrapplerFunctionItem* item) { + if (!IsConstant(input_const)) { + return errors::InvalidArgument("Input node ", input_const.name(), + " is not a constant"); + } + + auto& inputs = item->input_arg_expansions_; + + // Find input arg expansion and input placeholder position in it for the + // given function input position. + InputArgExpansion* input_arg_expansion = nullptr; + int placeholder_idx = input_position; + + for (InputArgExpansion& input : inputs) { + if (placeholder_idx < input.placeholders.size()) { + input_arg_expansion = &input; + break; + } + placeholder_idx -= input.placeholders.size(); + } + + if (input_arg_expansion == nullptr) { + return errors::InvalidArgument( + "Input placeholder not found: input_position=", input_position, + " function=", item->id); + } + + // Delete placeholder from input expansion. + string placeholder_name = input_arg_expansion->placeholders[placeholder_idx]; + item->input_arg_placeholders_.erase(placeholder_name); + input_arg_expansion->placeholders.erase( + input_arg_expansion->placeholders.begin() + placeholder_idx); + + // Delete empty input expansions. + inputs.erase(std::remove_if(inputs.begin(), inputs.end(), + [](const InputArgExpansion& input) { + return input.placeholders.empty(); + }), + inputs.end()); + + // Replace placeholder node in the function body with a const node. + for (NodeDef& node : *item->graph.mutable_node()) { + if (node.name() == placeholder_name) { + node = input_const; + node.set_name(placeholder_name); + node.clear_input(); // remove potential control inputs + node.clear_device(); // device placement is defined by instantiating node + } + } + + return Status::OK(); +} + Status MakeFunctionDef(const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib, FunctionDef* func) { @@ -579,6 +633,9 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item, // Add function input arguments. for (const InputArgExpansion& input_arg : item.inputs()) { + CHECK(input_arg.placeholders.size() == 1) // do some sanity checking + << "Inputs of tensor sequences are not supported"; + OpDef::ArgDef arg_def; arg_def.set_name(input_arg.input_name); arg_def.set_type(input_arg.data_type); @@ -588,15 +645,15 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item, // Add function output arguments. for (const OutputArgExpansion& output_arg : item.outputs()) { + CHECK(output_arg.output_tensors.size() == 1) // do some sanity checking + << "Outputs of tensor sequences are not supported"; + OpDef::ArgDef arg_def; arg_def.set_name(output_arg.output_name); arg_def.set_type(output_arg.data_type); arg_def.set_is_ref(output_arg.is_ref); *func->mutable_signature()->add_output_arg() = arg_def; - CHECK(output_arg.output_tensors.size() == 1) // do some sanity checking - << "Outputs of tensor sequences are not supported"; - string ret; for (const string& output_tensor : output_arg.output_tensors) { TF_RETURN_IF_ERROR(connectivity.AsFunctionDefInput(output_tensor, &ret)); diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h index 692333f..d9d71b8 100644 --- a/tensorflow/core/grappler/utils/functions.h +++ b/tensorflow/core/grappler/utils/functions.h @@ -162,6 +162,9 @@ class GrapplerFunctionItem : public GrapplerItem { GrapplerFunctionItem& SwapFunctionBody(GraphDef&& other); private: + friend Status ReplaceInputWithConst(const NodeDef&, int, + GrapplerFunctionItem*); + AttrValueMap func_attr_; // Attributes specific to function definition that // produced this item (FuncDef.attr field). @@ -189,12 +192,16 @@ bool HasParametrizedBody(const FunctionDef& func); bool IsParametrized(const FunctionDef& func); // Register GrapplerFunctionItem input arg expansion and function body outputs -// in the GrapplerFunctionConnectivity. Use function library definition to +// in the GrapplerFunctionConnectivity. Use function library definition to // lookup function body nodes output names and ranges. Status RegisterGrapplerFunctionConnectivity( const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib, GrapplerFunctionConnectivity* connectivity); +// Replace one of the function inputs with a constant. +Status ReplaceInputWithConst(const NodeDef& input_const, int input_position, + GrapplerFunctionItem* item); + // Make a GrapplerFunctionItem from the function definition and function // instantiation attributes (caller node attributes). Returns error if the given // function def cannot be converted (e.g. not all attributes are defined). diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc index 6dfd49b..fa6fec7 100644 --- a/tensorflow/core/grappler/utils/functions_test.cc +++ b/tensorflow/core/grappler/utils/functions_test.cc @@ -573,6 +573,81 @@ TEST_F(FunctionsTest, MakeFunctionDef) { EXPECT_EQ(2, count); } +TEST_F(FunctionsTest, ReplaceInputWithConst) { + FunctionDef func = FunctionDefHelper::Create( + "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"}, + {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}}, + /* Mapping between function returns and function node outputs. */ + {{"z", "output:z:0"}}); + + std::unordered_map func_attr; + func_attr["T"].set_type(DT_FLOAT); + FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); + + GrapplerFunctionItem item; + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); + + EXPECT_EQ(2, item.input_size()); + EXPECT_EQ(1, item.output_size()); + + ASSERT_EQ(3, item.function_body().node_size()); + + const NodeDef &input_x = item.function_body().node(0); + const NodeDef &input_y = item.function_body().node(1); + + // Initially inputs added to the graph as placeholders. + EXPECT_EQ("Placeholder", input_x.op()); + EXPECT_EQ("Placeholder", input_y.op()); + + // Replace inputs x and y with constants. + NodeDef const_input_x; + const_input_x.set_op("Const"); + AddNodeAttr("Tag", "const_input_x", &const_input_x); + + NodeDef const_input_y; + const_input_y.set_op("Const"); + AddNodeAttr("Tag", "const_input_y", &const_input_y); + + // Replace input x. + TF_EXPECT_OK(ReplaceInputWithConst(const_input_x, 0, &item)); + + EXPECT_EQ(1, item.input_size()); + EXPECT_EQ("Const", input_x.op()); + EXPECT_EQ("const_input_x", input_x.attr().at("Tag").s()); + + // Replace input y. + TF_EXPECT_OK(ReplaceInputWithConst(const_input_y, 0, &item)); + + EXPECT_EQ(0, item.input_size()); + EXPECT_EQ("Const", input_y.op()); + EXPECT_EQ("const_input_y", input_y.attr().at("Tag").s()); + + // Make a function from const-specialized function item. + FunctionDef specialized; + TF_EXPECT_OK(MakeFunctionDef(item, flib, &specialized)); + + EXPECT_EQ(0, specialized.signature().input_arg_size()); + EXPECT_EQ(1, specialized.signature().output_arg_size()); + EXPECT_EQ(3, specialized.node_def_size()); + + // Check that graph has const nodes pushed into function body. + int count = 0; + for (const NodeDef &node : specialized.node_def()) { + if (node.name() == "x" && count++) { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ("const_input_x", node.attr().at("Tag").s()); + } else if (node.name() == "y" && count++) { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ("const_input_y", node.attr().at("Tag").s()); + } else if (node.name() == "output" && count++) { + EXPECT_EQ("Mul", node.op()); + EXPECT_EQ("x:output:0", node.input(0)); + EXPECT_EQ("y:output:0", node.input(1)); + } + } + EXPECT_EQ(3, count); +} + TEST_F(FunctionsTest, SwapFunctionBodyAndMakeFunctionDef) { using test::function::NDef; -- 2.7.4