return Status::OK();
}
+Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
+ GrapplerFunctionItem* item) {
+ if (!IsConstant(input_const)) {
+ return errors::InvalidArgument("Input node ", input_const.name(),
+ " is not a constant");
+ }
+
+ auto& inputs = item->input_arg_expansions_;
+
+ // Find input arg expansion and input placeholder position in it for the
+ // given function input position.
+ InputArgExpansion* input_arg_expansion = nullptr;
+ int placeholder_idx = input_position;
+
+ for (InputArgExpansion& input : inputs) {
+ if (placeholder_idx < input.placeholders.size()) {
+ input_arg_expansion = &input;
+ break;
+ }
+ placeholder_idx -= input.placeholders.size();
+ }
+
+ if (input_arg_expansion == nullptr) {
+ return errors::InvalidArgument(
+ "Input placeholder not found: input_position=", input_position,
+ " function=", item->id);
+ }
+
+ // Delete placeholder from input expansion.
+ string placeholder_name = input_arg_expansion->placeholders[placeholder_idx];
+ item->input_arg_placeholders_.erase(placeholder_name);
+ input_arg_expansion->placeholders.erase(
+ input_arg_expansion->placeholders.begin() + placeholder_idx);
+
+ // Delete empty input expansions.
+ inputs.erase(std::remove_if(inputs.begin(), inputs.end(),
+ [](const InputArgExpansion& input) {
+ return input.placeholders.empty();
+ }),
+ inputs.end());
+
+ // Replace placeholder node in the function body with a const node.
+ for (NodeDef& node : *item->graph.mutable_node()) {
+ if (node.name() == placeholder_name) {
+ node = input_const;
+ node.set_name(placeholder_name);
+ node.clear_input(); // remove potential control inputs
+ node.clear_device(); // device placement is defined by instantiating node
+ }
+ }
+
+ return Status::OK();
+}
+
Status MakeFunctionDef(const GrapplerFunctionItem& item,
const FunctionLibraryDefinition& flib,
FunctionDef* func) {
// Add function input arguments.
for (const InputArgExpansion& input_arg : item.inputs()) {
+ CHECK(input_arg.placeholders.size() == 1) // do some sanity checking
+ << "Inputs of tensor sequences are not supported";
+
OpDef::ArgDef arg_def;
arg_def.set_name(input_arg.input_name);
arg_def.set_type(input_arg.data_type);
// Add function output arguments.
for (const OutputArgExpansion& output_arg : item.outputs()) {
+ CHECK(output_arg.output_tensors.size() == 1) // do some sanity checking
+ << "Outputs of tensor sequences are not supported";
+
OpDef::ArgDef arg_def;
arg_def.set_name(output_arg.output_name);
arg_def.set_type(output_arg.data_type);
arg_def.set_is_ref(output_arg.is_ref);
*func->mutable_signature()->add_output_arg() = arg_def;
- CHECK(output_arg.output_tensors.size() == 1) // do some sanity checking
- << "Outputs of tensor sequences are not supported";
-
string ret;
for (const string& output_tensor : output_arg.output_tensors) {
TF_RETURN_IF_ERROR(connectivity.AsFunctionDefInput(output_tensor, &ret));
GrapplerFunctionItem& SwapFunctionBody(GraphDef&& other);
private:
+ friend Status ReplaceInputWithConst(const NodeDef&, int,
+ GrapplerFunctionItem*);
+
AttrValueMap func_attr_; // Attributes specific to function definition that
// produced this item (FuncDef.attr field).
bool IsParametrized(const FunctionDef& func);
// Register GrapplerFunctionItem input arg expansion and function body outputs
-// in the GrapplerFunctionConnectivity. Use function library definition to
+// in the GrapplerFunctionConnectivity. Use function library definition to
// lookup function body nodes output names and ranges.
Status RegisterGrapplerFunctionConnectivity(
const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib,
GrapplerFunctionConnectivity* connectivity);
+// Replace one of the function inputs with a constant.
+Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
+ GrapplerFunctionItem* item);
+
// Make a GrapplerFunctionItem from the function definition and function
// instantiation attributes (caller node attributes). Returns error if the given
// function def cannot be converted (e.g. not all attributes are defined).
EXPECT_EQ(2, count);
}
+TEST_F(FunctionsTest, ReplaceInputWithConst) {
+ FunctionDef func = FunctionDefHelper::Create(
+ "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
+ {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "output:z:0"}});
+
+ std::unordered_map<string, AttrValue> func_attr;
+ func_attr["T"].set_type(DT_FLOAT);
+ FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
+
+ GrapplerFunctionItem item;
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+
+ EXPECT_EQ(2, item.input_size());
+ EXPECT_EQ(1, item.output_size());
+
+ ASSERT_EQ(3, item.function_body().node_size());
+
+ const NodeDef &input_x = item.function_body().node(0);
+ const NodeDef &input_y = item.function_body().node(1);
+
+ // Initially inputs added to the graph as placeholders.
+ EXPECT_EQ("Placeholder", input_x.op());
+ EXPECT_EQ("Placeholder", input_y.op());
+
+ // Replace inputs x and y with constants.
+ NodeDef const_input_x;
+ const_input_x.set_op("Const");
+ AddNodeAttr("Tag", "const_input_x", &const_input_x);
+
+ NodeDef const_input_y;
+ const_input_y.set_op("Const");
+ AddNodeAttr("Tag", "const_input_y", &const_input_y);
+
+ // Replace input x.
+ TF_EXPECT_OK(ReplaceInputWithConst(const_input_x, 0, &item));
+
+ EXPECT_EQ(1, item.input_size());
+ EXPECT_EQ("Const", input_x.op());
+ EXPECT_EQ("const_input_x", input_x.attr().at("Tag").s());
+
+ // Replace input y.
+ TF_EXPECT_OK(ReplaceInputWithConst(const_input_y, 0, &item));
+
+ EXPECT_EQ(0, item.input_size());
+ EXPECT_EQ("Const", input_y.op());
+ EXPECT_EQ("const_input_y", input_y.attr().at("Tag").s());
+
+ // Make a function from const-specialized function item.
+ FunctionDef specialized;
+ TF_EXPECT_OK(MakeFunctionDef(item, flib, &specialized));
+
+ EXPECT_EQ(0, specialized.signature().input_arg_size());
+ EXPECT_EQ(1, specialized.signature().output_arg_size());
+ EXPECT_EQ(3, specialized.node_def_size());
+
+ // Check that graph has const nodes pushed into function body.
+ int count = 0;
+ for (const NodeDef &node : specialized.node_def()) {
+ if (node.name() == "x" && count++) {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("const_input_x", node.attr().at("Tag").s());
+ } else if (node.name() == "y" && count++) {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("const_input_y", node.attr().at("Tag").s());
+ } else if (node.name() == "output" && count++) {
+ EXPECT_EQ("Mul", node.op());
+ EXPECT_EQ("x:output:0", node.input(0));
+ EXPECT_EQ("y:output:0", node.input(1));
+ }
+ }
+ EXPECT_EQ(3, count);
+}
+
TEST_F(FunctionsTest, SwapFunctionBodyAndMakeFunctionDef) {
using test::function::NDef;