==============================================================================*/
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
+
#include <unordered_map>
+
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
return unique_name;
}
+// Specialized function instantiation type parameters, body parameters, and
+// const inputs.
+struct FunctionSpecializationSignature {
+ string func_name;
+ std::unordered_map<string, DataType> type_parameters;
+ std::unordered_map<string, AttrValue> body_parameters;
+ std::unordered_map<int, string> const_inputs;
+
+ bool operator==(const FunctionSpecializationSignature& other) const {
+ bool equals = func_name == other.func_name &&
+ type_parameters == other.type_parameters &&
+ const_inputs == other.const_inputs;
+
+ if (!equals) return false;
+
+ // Equality is not defined for AttrValue.
+ if (body_parameters.size() != other.body_parameters.size()) return false;
+
+ for (const auto& lhs : body_parameters) {
+ auto it = other.body_parameters.find(lhs.first);
+ if (it == other.body_parameters.end()) return false;
+ if (!AreAttrValuesEqual(lhs.second, (*it).second)) return false;
+ }
+
+ return true;
+ }
+
+ struct Hash {
+ uint64 operator()(FunctionSpecializationSignature const& s) const {
+ uint64 h = Hash64(s.func_name);
+
+ // Use std::map for deterministic iteration order.
+
+ std::map<string, DataType> types(s.type_parameters.begin(),
+ s.type_parameters.end());
+ for (const auto& pair : types) {
+ AttrValue attr_value;
+ attr_value.set_type(pair.second);
+ h = Hash64Combine(Hash64(pair.first), h);
+ h = Hash64Combine(AttrValueHash(attr_value), h);
+ }
+
+ std::map<string, AttrValue> body(s.body_parameters.begin(),
+ s.body_parameters.end());
+ for (const auto& pair : body) {
+ h = Hash64Combine(Hash64(pair.first), h);
+ h = Hash64Combine(AttrValueHash(pair.second), h);
+ }
+
+ std::map<int, string> inputs(s.const_inputs.begin(),
+ s.const_inputs.end());
+ for (const auto& pair : inputs) {
+ h = Hash64Combine(std::hash<int>()(pair.first), h);
+ h = Hash64Combine(Hash64(pair.second), h);
+ }
+
+ return h;
+ }
+ };
+};
+
+struct FunctionSpecialization {
+ string specialized_func_name;
+ std::unordered_set<string> const_inputs;
+ std::unordered_set<string> control_deps;
+};
+
class FunctionOptimizerContext {
public:
explicit FunctionOptimizerContext(RewriterConfig::Toggle opt_level,
return gtl::FindWithDefault(inlined_functions_, name, nullptr);
}
+ const FunctionSpecialization* FindFunctionSpecialization(
+ const FunctionSpecializationSignature& sig) const {
+ return gtl::FindOrNull(specialized_functions_, sig);
+ }
+
+ void AddSpecializedFunction(const FunctionSpecializationSignature& sig,
+ const FunctionSpecialization& specialized_func) {
+ specialized_functions_.emplace(sig, specialized_func);
+ }
+
private:
void InitializeTrulyConstNodes(const GrapplerItem& item) {
std::unordered_set<string> feed_nodes;
// Nodes that are Const and not in feed.
std::unordered_map<string, const NodeDef*> truly_const_nodes_;
+ // Specialized functions.
+ std::unordered_map<FunctionSpecializationSignature,
+ const FunctionSpecialization,
+ FunctionSpecializationSignature::Hash>
+ specialized_functions_;
+
TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext);
};
for (const string& ctrl : control_deps) {
if (existing_control_deps.find(ctrl) == existing_control_deps.end()) {
- VLOG(3) << "Forward control dependency to function caller node: input="
- << ctrl;
+ VLOG(3) << "Forward control dependency: input=" << ctrl;
specialized_func_node->add_input(ctrl);
}
}
}
}
+Status InitializeFunctionSpecializationSignature(
+ const NodeDef& func_node, const FunctionDef& func,
+ const AttrValueMap& func_attr, const FunctionOptimizerContext& ctx,
+ FunctionSpecializationSignature* sig) {
+ sig->func_name = func.signature().name();
+
+ TF_RETURN_IF_ERROR(
+ InstantiationTypeParameters(func, func_attr, &sig->type_parameters));
+ TF_RETURN_IF_ERROR(
+ InstantiationBodyParameters(func, func_attr, &sig->body_parameters));
+
+ for (int i = 0; i < func_node.input_size(); ++i) {
+ const string& input = func_node.input(i);
+ if (ctx.IsTrulyConst(input)) {
+ sig->const_inputs.emplace(i, input);
+ }
+ }
+
+ return Status::OK();
+}
+
Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
FunctionOptimizerContext* ctx,
GraphDef* optimized_graph) {
const std::unordered_map<string, AttrValue> func_attr(
func_node.attr().begin(), func_node.attr().end());
+ FunctionSpecializationSignature signature;
+ TF_RETURN_IF_ERROR(InitializeFunctionSpecializationSignature(
+ func_node, func, func_attr, *ctx, &signature));
+
+ // Check if function was already specialized for identical context.
+ const FunctionSpecialization* already_specialized =
+ ctx->FindFunctionSpecialization(signature);
+
+ if (already_specialized) {
+ VLOG(2) << "Function was already specialized in identical context: "
+ "specialized_name="
+ << already_specialized->specialized_func_name;
+
+ // Add a function call node for the specialized function.
+ NodeDef* specialized_func_node = optimized_graph->add_node();
+ *specialized_func_node = func_node;
+ specialized_func_node->set_op(already_specialized->specialized_func_name);
+
+ RemovePushedDownConstInputs(already_specialized->const_inputs,
+ already_specialized->control_deps,
+ specialized_func_node);
+
+ return Status::OK();
+ }
+
+ // Add a new specialized function definition to the library.
const auto& flib = ctx->function_library();
// Make a GrapplerFunctionItem and convert it back to FunctionDef after
// Update specialized node to remove inputs for pushed down consts.
RemovePushedDownConstInputs(const_inputs, control_deps,
specialized_func_node);
+
+ ctx->AddSpecializedFunction(
+ signature, {specialized_func_name, const_inputs, control_deps});
+
return Status::OK();
}
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
+TEST_F(FunctionOptimizerTest, SpecializeFunction_OncePerUniqueContext) {
+ using test::function::NDef;
+
+ FunctionOptimizer optimizer(RewriterConfig::DEFAULT);
+
+ // Mark MyMul as noinline.
+ FunctionDef mul_func = FunctionDefHelper::Create(
+ "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, int32}"},
+ {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "output:z:0"}});
+ (*mul_func.mutable_attr())["_noinline"].set_b(true);
+ std::vector<FunctionDef> function_library = {mul_func};
+
+ const Tensor kTwo = test::AsScalar<float>(2.0);
+ const Tensor kThree = test::AsScalar<float>(3.0);
+
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("init", "NoOp", {}, {}, kDevice),
+
+ // Float placeholders.
+ NDef("xf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("yf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+
+ // Int32 placeholders.
+ NDef("xi", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
+ NDef("yi", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
+
+ // Consts. Control inputs has to be attached to specialized func calls.
+ NDef("two", "Const", {"^init", "^xf"},
+ {{"dtype", DT_FLOAT}, {"value", kTwo}}, kDevice),
+ NDef("three", "Const", {"^init", "^xf"},
+ {{"dtype", DT_FLOAT}, {"value", kThree}}, kDevice),
+
+ // Specialization #1: DT_FLOAT type parameter.
+ NDef("mul_1", "MyMul", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
+ NDef("mul_2", "MyMul", {"yf", "xf"}, {{"T", DT_FLOAT}}, kDevice),
+
+ // Specialization #2: DT_INT32 type parameter.
+ NDef("mul_3", "MyMul", {"xi", "yi"}, {{"T", DT_INT32}}, kDevice),
+
+ // Specialization #3: DT_FLOAT type parameter + const input kTwo.
+ NDef("mul_4", "MyMul", {"xf", "two"}, {{"T", DT_FLOAT}}, kDevice),
+ NDef("mul_5", "MyMul", {"yf", "two"}, {{"T", DT_FLOAT}}, kDevice),
+
+ // Specialization #4: DT_FLOAT type parameter + const input kThree.
+ NDef("mul_6", "MyMul", {"three", "xf"}, {{"T", DT_FLOAT}}, kDevice)},
+ function_library);
+
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ // Make sure that MyMul was specialized once per unique context.
+ EXPECT_EQ(4, output.library().function_size());
+
+ // And graph nodes calling specialized functions.
+ int count = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "mul_1" && count++) {
+ EXPECT_EQ("MyMul_specialized_for_mul_1", node.op());
+ ASSERT_EQ(2, node.input_size());
+ EXPECT_EQ("xf", node.input(0));
+ EXPECT_EQ("yf", node.input(1));
+
+ } else if (node.name() == "mul_2" && count++) {
+ EXPECT_EQ("MyMul_specialized_for_mul_1", node.op());
+ ASSERT_EQ(2, node.input_size());
+ EXPECT_EQ("yf", node.input(0));
+ EXPECT_EQ("xf", node.input(1));
+
+ } else if (node.name() == "mul_3" && count++) {
+ EXPECT_EQ("MyMul_specialized_for_mul_3", node.op());
+ ASSERT_EQ(2, node.input_size());
+ EXPECT_EQ("xi", node.input(0));
+ EXPECT_EQ("yi", node.input(1));
+
+ } else if (node.name() == "mul_4" && count++) {
+ EXPECT_EQ("MyMul_specialized_for_mul_4", node.op());
+ ASSERT_EQ(2, node.input_size());
+ EXPECT_EQ("xf", node.input(0));
+ EXPECT_EQ("^init", node.input(1));
+
+ } else if (node.name() == "mul_5" && count++) {
+ EXPECT_EQ("MyMul_specialized_for_mul_4", node.op());
+ ASSERT_EQ(3, node.input_size());
+ EXPECT_EQ("yf", node.input(0));
+ EXPECT_EQ("^init", node.input(1));
+ EXPECT_EQ("^xf", node.input(2));
+
+ } else if (node.name() == "mul_6" && count++) {
+ EXPECT_EQ("MyMul_specialized_for_mul_6", node.op());
+ ASSERT_EQ(2, node.input_size());
+ EXPECT_EQ("xf", node.input(0));
+ EXPECT_EQ("^init", node.input(1));
+ }
+ }
+ EXPECT_EQ(6, count);
+
+ // And that graph evaluation yields the same result.
+ Tensor pi = test::AsScalar<float>(3.14f);
+ Tensor four = test::AsScalar<int32>(4);
+ item.fetch = {"mul_1", "mul_2", "mul_3", "mul_4", "mul_5", "mul_6"};
+ item.feed = {{"xf", pi}, {"yf", pi}, {"xi", four}, {"yi", four}};
+
+ auto tensors_expected = EvaluateFetchNodes(item);
+ GrapplerItem optimized(item, std::move(output));
+ auto tensors = EvaluateFetchNodes(optimized);
+
+ test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+ test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
+ test::ExpectTensorEqual<int32>(tensors_expected[2], tensors[2]);
+ test::ExpectTensorEqual<float>(tensors_expected[3], tensors[3]);
+ test::ExpectTensorEqual<float>(tensors_expected[4], tensors[4]);
+ test::ExpectTensorEqual<float>(tensors_expected[5], tensors[5]);
+}
+
} // namespace grappler
} // namespace tensorflow
output.library());
// Specialized and optimized functions should be added to the graph.
- EXPECT_EQ(6, optimized_flib.num_functions());
+ EXPECT_EQ(5, optimized_flib.num_functions());
// MyQuadratic should be specialized once:
// 0. 'quadratic' node in the main graph
const string optimized_0 = "MyQuadratic_specialized_for_quadratic";
// MySquare should be specialized and optimized for 3 instantiations:
- // 1. 'square' node in the main graph
- // 2. 'square' node in the MyQuadratic specialization
- // 3. 'quadratic' node in the MyQuadratic specialization
+ // 1. 'square' node in the main graph
+ // 2. 'square' node in the MyQuadratic specialization
+ // 3*. 'quadratic' node in the MyQuadratic specialization
+ // has identical instantiation context to #2
const string optimized_1 = "MySquare_specialized_for_square";
const string optimized_2 = "MySquare_specialized_for_square_1";
- const string optimized_3 = "MySquare_specialized_for_quadratic";
const FunctionDef* optimized_func_0 = optimized_flib.Find(optimized_0);
const FunctionDef* optimized_func_1 = optimized_flib.Find(optimized_1);
const FunctionDef* optimized_func_2 = optimized_flib.Find(optimized_2);
- const FunctionDef* optimized_func_3 = optimized_flib.Find(optimized_3);
ASSERT_NE(optimized_func_0, nullptr);
ASSERT_NE(optimized_func_1, nullptr);
ASSERT_NE(optimized_func_2, nullptr);
- ASSERT_NE(optimized_func_3, nullptr);
// Graph should call optimized function.
int count = 0;
if (node.name() == "square" && count++) {
EXPECT_EQ(optimized_2, node.op());
} else if (node.name() == "quadratic" && count++) {
- EXPECT_EQ(optimized_3, node.op());
+ // Share specialized function with the 'square' node.
+ EXPECT_EQ(optimized_2, node.op());
}
}
EXPECT_EQ(2, count);
- const std::vector<const FunctionDef*> optimized_funcs = {
- optimized_func_1, optimized_func_1, optimized_func_3};
+ const std::vector<const FunctionDef*> optimized_funcs = {optimized_func_1,
+ optimized_func_2};
// MyMul should be inlined into all optimized versions of MySquare.
for (const FunctionDef* optimized_func : optimized_funcs) {
return HasParametrizedType(func) || HasParametrizedBody(func);
}
+Status InstantiationTypeParameters(
+ const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+ std::unordered_map<string, DataType>* type_parameters) {
+ if (!type_parameters->empty()) {
+ return errors::InvalidArgument("Type parameters output map must be empty");
+ }
+
+ GrapplerFunctionItemInstantiation instantiation(&func_instantiation_attr);
+
+ const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) {
+ // Check if it's unknown and unresolved type.
+ if (arg.type() == DT_INVALID &&
+ type_parameters->find(arg.type_attr()) == type_parameters->end()) {
+ DataType data_type;
+ TF_RETURN_IF_ERROR(instantiation.GetArgType(arg, &data_type));
+ type_parameters->insert({arg.type_attr(), data_type});
+ }
+ return Status::OK();
+ };
+
+ for (const auto& input : func.signature().input_arg())
+ TF_RETURN_IF_ERROR(resolve_type_attr(input));
+ for (const auto& output : func.signature().output_arg())
+ TF_RETURN_IF_ERROR(resolve_type_attr(output));
+
+ return Status::OK();
+}
+
+Status InstantiationBodyParameters(
+ const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+ std::unordered_map<string, AttrValue>* body_parameters) {
+ if (!body_parameters->empty()) {
+ return errors::InvalidArgument("Body parameters output map must be empty");
+ }
+
+ for (const NodeDef& func_body_node : func.node_def()) {
+ for (auto& attr : func_body_node.attr()) {
+ const string& placeholder = attr.second.placeholder();
+
+ if (placeholder.empty() ||
+ body_parameters->find(placeholder) != body_parameters->end()) {
+ continue;
+ }
+
+ auto it = func_instantiation_attr.find(placeholder);
+ if (it != func_instantiation_attr.end()) {
+ body_parameters->emplace(placeholder, it->second);
+ } else {
+ return errors::InvalidArgument("Can't resolve placeholder: ",
+ placeholder);
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
Status MakeGrapplerFunctionItem(const FunctionDef& func,
const AttrValueMap& func_instantiation_attr,
const FunctionLibraryDefinition& flib,
// Check if function has parametrized type or body.
bool IsParametrized(const FunctionDef& func);
+// Resolve function instantiation type parameters from the attributes of the
+// caller node. Return error if type can't be resolved.
+Status InstantiationTypeParameters(
+ const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+ std::unordered_map<string, DataType>* type_parameters);
+
+// Resolve function instantiation body parameters (values for the function body
+// attr placeholders) from the attributes of the caller node. Return error if
+// a placeholder value can't be resolved.
+Status InstantiationBodyParameters(
+ const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+ std::unordered_map<string, AttrValue>* body_parameters);
+
// Register GrapplerFunctionItem input arg expansion and function body outputs
// in the GrapplerFunctionConnectivity. Use function library definition to
// lookup function body nodes output names and ranges.
// Make a GrapplerFunctionItem from the function definition and function
// instantiation attributes (caller node attributes). Returns error if the given
// function def cannot be converted (e.g. not all attributes are defined).
-Status MakeGrapplerFunctionItem(
- const FunctionDef& func,
- const std::unordered_map<string, AttrValue>& func_instantiation_attr,
- const FunctionLibraryDefinition& flib, GrapplerFunctionItem* item);
+Status MakeGrapplerFunctionItem(const FunctionDef& func,
+ const AttrValueMap& func_instantiation_attr,
+ const FunctionLibraryDefinition& flib,
+ GrapplerFunctionItem* item);
// Make a GrapplerFunction item from the function definition. Function must be
// fully defined (no type or body parametrization).
EXPECT_FALSE(IsParametrized(non_parametrized_func));
}
+TEST_F(FunctionsTest, InstantiationParameters) {
+ // Function definition is invalid, only type/body parameters are important.
+ FunctionDef func = FunctionDefHelper::Create(
+ "ParametrizedFunc",
+ /* inputs */
+ {"input1:A", "input2:B", "input3:float"},
+ /* outputs */
+ {"output1: A", "output2:C"},
+ /* type parameters */
+ {"A: {float, double}", "B: {float, int32}", "C: {float, double}"},
+ /* function body*/
+ {{{"output"}, "FakeOp", {"input1", "input2"}, {{"key", "$key"}}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"x", "cx:output:0"}, {"y", "cy:output:0"}});
+
+ std::unordered_map<string, AttrValue> func_instantiation_attr;
+ func_instantiation_attr["key"].set_s("key-value");
+ func_instantiation_attr["A"].set_type(DT_FLOAT);
+ func_instantiation_attr["B"].set_type(DT_INT32);
+ func_instantiation_attr["C"].set_type(DT_DOUBLE);
+
+ std::unordered_map<string, DataType> type_parameters;
+ TF_EXPECT_OK(InstantiationTypeParameters(func, func_instantiation_attr,
+ &type_parameters));
+
+ ASSERT_EQ(3, type_parameters.size());
+ EXPECT_EQ(DT_FLOAT, type_parameters["A"]);
+ EXPECT_EQ(DT_INT32, type_parameters["B"]);
+ EXPECT_EQ(DT_DOUBLE, type_parameters["C"]);
+
+ std::unordered_map<string, AttrValue> body_parameters;
+ TF_EXPECT_OK(InstantiationBodyParameters(func, func_instantiation_attr,
+ &body_parameters));
+
+ ASSERT_EQ(1, body_parameters.size());
+ EXPECT_EQ("key-value", body_parameters["key"].s());
+}
+
TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandFunctionDefInput) {
GrapplerFunctionConnectivity connectivity;