From: A. Unique TensorFlower
Date: Mon, 7 May 2018 20:18:33 +0000 (-0700)
Subject: Specialize functions only once per unique context.
X-Git-Tag: upstream/v1.9.0_rc1~150^2~1^2~66
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=914c971c7b690661754e83549325c5deadd9e62d;p=platform%2Fupstream%2Ftensorflow.git

Specialize functions only once per unique context.

PiperOrigin-RevId: 195710562
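This change makes Grappler's function optimizer cache its function specializations. Each specialization is keyed by an instantiation signature (function name, resolved type parameters, function-body attribute placeholders, and const inputs pushed down from the caller), so call sites with an identical instantiation context now share a single specialized FunctionDef instead of each producing its own copy. The sketch below is a minimal, standalone illustration of that caching pattern; Signature, Specialization, and SpecializationCache are simplified stand-ins invented for this note, not the classes added by the patch.

    #include <cstddef>
    #include <functional>
    #include <map>
    #include <string>
    #include <unordered_map>

    // Instantiation context of a function call site: function name, resolved
    // type parameters, and const inputs (caller input index -> const node name).
    struct Signature {
      std::string func_name;
      std::map<std::string, std::string> type_parameters;
      std::map<int, std::string> const_inputs;

      bool operator==(const Signature& other) const {
        return func_name == other.func_name &&
               type_parameters == other.type_parameters &&
               const_inputs == other.const_inputs;
      }

      // Hashing only the function name is valid (equal signatures still hash
      // equally); the real patch also folds the parameter maps into the hash
      // in a deterministic order (see the note at the end of this patch).
      struct Hash {
        std::size_t operator()(const Signature& s) const {
          return std::hash<std::string>()(s.func_name);
        }
      };
    };

    struct Specialization {
      std::string specialized_func_name;
    };

    class SpecializationCache {
     public:
      // Reuse an existing specialization for an identical context, or create
      // and cache a new one.
      const Specialization& GetOrCreate(const Signature& sig) {
        auto it = cache_.find(sig);
        if (it != cache_.end()) return it->second;
        Specialization spec{sig.func_name + "_specialized_" +
                            std::to_string(cache_.size())};
        return cache_.emplace(sig, spec).first->second;
      }

     private:
      std::unordered_map<Signature, Specialization, Signature::Hash> cache_;
    };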
---

diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index 1bec908..a44e1ee 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -14,10 +14,13 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
+
+#include <unordered_map>
+
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/graph_def_util.h"
@@ -74,6 +77,73 @@ string UniqueSpecializedFunctionName(const FunctionDef& func,
   return unique_name;
 }
 
+// Specialized function instantiation: type parameters, body parameters, and
+// const inputs.
+struct FunctionSpecializationSignature {
+  string func_name;
+  std::unordered_map<string, DataType> type_parameters;
+  std::unordered_map<string, AttrValue> body_parameters;
+  std::unordered_map<int, string> const_inputs;
+
+  bool operator==(const FunctionSpecializationSignature& other) const {
+    bool equals = func_name == other.func_name &&
+                  type_parameters == other.type_parameters &&
+                  const_inputs == other.const_inputs;
+
+    if (!equals) return false;
+
+    // Equality is not defined for AttrValue.
+    if (body_parameters.size() != other.body_parameters.size()) return false;
+
+    for (const auto& lhs : body_parameters) {
+      auto it = other.body_parameters.find(lhs.first);
+      if (it == other.body_parameters.end()) return false;
+      if (!AreAttrValuesEqual(lhs.second, (*it).second)) return false;
+    }
+
+    return true;
+  }
+
+  struct Hash {
+    uint64 operator()(FunctionSpecializationSignature const& s) const {
+      uint64 h = Hash64(s.func_name);
+
+      // Use std::map for deterministic iteration order.
+      std::map<string, DataType> types(s.type_parameters.begin(),
+                                       s.type_parameters.end());
+      for (const auto& pair : types) {
+        AttrValue attr_value;
+        attr_value.set_type(pair.second);
+        h = Hash64Combine(Hash64(pair.first), h);
+        h = Hash64Combine(AttrValueHash(attr_value), h);
+      }
+
+      std::map<string, AttrValue> body(s.body_parameters.begin(),
+                                       s.body_parameters.end());
+      for (const auto& pair : body) {
+        h = Hash64Combine(Hash64(pair.first), h);
+        h = Hash64Combine(AttrValueHash(pair.second), h);
+      }
+
+      std::map<int, string> inputs(s.const_inputs.begin(),
+                                   s.const_inputs.end());
+      for (const auto& pair : inputs) {
+        h = Hash64Combine(std::hash<int>()(pair.first), h);
+        h = Hash64Combine(Hash64(pair.second), h);
+      }
+
+      return h;
+    }
+  };
+};
+
+struct FunctionSpecialization {
+  string specialized_func_name;
+  std::unordered_set<string> const_inputs;
+  std::unordered_set<string> control_deps;
+};
+
 class FunctionOptimizerContext {
  public:
   explicit FunctionOptimizerContext(RewriterConfig::Toggle opt_level,
@@ -108,6 +178,16 @@ class FunctionOptimizerContext {
     return gtl::FindWithDefault(inlined_functions_, name, nullptr);
   }
 
+  const FunctionSpecialization* FindFunctionSpecialization(
+      const FunctionSpecializationSignature& sig) const {
+    return gtl::FindOrNull(specialized_functions_, sig);
+  }
+
+  void AddSpecializedFunction(const FunctionSpecializationSignature& sig,
+                              const FunctionSpecialization& specialized_func) {
+    specialized_functions_.emplace(sig, specialized_func);
+  }
+
  private:
   void InitializeTrulyConstNodes(const GrapplerItem& item) {
     std::unordered_set<string> feed_nodes;
@@ -148,6 +228,12 @@ class FunctionOptimizerContext {
   // Nodes that are Const and not in feed.
   std::unordered_map<string, const NodeDef*> truly_const_nodes_;
 
+  // Specialized functions.
+  std::unordered_map<FunctionSpecializationSignature,
+                     FunctionSpecialization,
+                     FunctionSpecializationSignature::Hash>
+      specialized_functions_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext);
 };
 
@@ -303,14 +389,34 @@ void RemovePushedDownConstInputs(const std::unordered_set<string>& const_inputs,
     for (const string& ctrl : control_deps) {
       if (existing_control_deps.find(ctrl) == existing_control_deps.end()) {
-        VLOG(3) << "Forward control dependency to function caller node: input="
-                << ctrl;
+        VLOG(3) << "Forward control dependency: input=" << ctrl;
         specialized_func_node->add_input(ctrl);
       }
     }
   }
 }
 
+Status InitializeFunctionSpecializationSignature(
+    const NodeDef& func_node, const FunctionDef& func,
+    const AttrValueMap& func_attr, const FunctionOptimizerContext& ctx,
+    FunctionSpecializationSignature* sig) {
+  sig->func_name = func.signature().name();
+
+  TF_RETURN_IF_ERROR(
+      InstantiationTypeParameters(func, func_attr, &sig->type_parameters));
+  TF_RETURN_IF_ERROR(
+      InstantiationBodyParameters(func, func_attr, &sig->body_parameters));
+
+  for (int i = 0; i < func_node.input_size(); ++i) {
+    const string& input = func_node.input(i);
+    if (ctx.IsTrulyConst(input)) {
+      sig->const_inputs.emplace(i, input);
+    }
+  }
+
+  return Status::OK();
+}
+
 Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
                           FunctionOptimizerContext* ctx,
                           GraphDef* optimized_graph) {
@@ -320,6 +426,32 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
   const std::unordered_map<string, AttrValue> func_attr(
       func_node.attr().begin(), func_node.attr().end());
 
+  FunctionSpecializationSignature signature;
+  TF_RETURN_IF_ERROR(InitializeFunctionSpecializationSignature(
+      func_node, func, func_attr, *ctx, &signature));
+
+  // Check if function was already specialized for identical context.
+  const FunctionSpecialization* already_specialized =
+      ctx->FindFunctionSpecialization(signature);
+
+  if (already_specialized) {
+    VLOG(2) << "Function was already specialized in identical context: "
+               "specialized_name="
+            << already_specialized->specialized_func_name;
+
+    // Add a function call node for the specialized function.
+    NodeDef* specialized_func_node = optimized_graph->add_node();
+    *specialized_func_node = func_node;
+    specialized_func_node->set_op(already_specialized->specialized_func_name);
+
+    RemovePushedDownConstInputs(already_specialized->const_inputs,
+                                already_specialized->control_deps,
+                                specialized_func_node);
+
+    return Status::OK();
+  }
+
+  // Add a new specialized function definition to the library.
   const auto& flib = ctx->function_library();
 
   // Make a GrapplerFunctionItem and convert it back to FunctionDef after
@@ -358,6 +490,10 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
   // Update specialized node to remove inputs for pushed down consts.
   RemovePushedDownConstInputs(const_inputs, control_deps,
                               specialized_func_node);
+
+  ctx->AddSpecializedFunction(
+      signature, {specialized_func_name, const_inputs, control_deps});
+
   return Status::OK();
 }
 
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
index 147a264..a2dbab3 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
@@ -718,5 +718,122 @@ TEST_F(FunctionOptimizerTest, SpecializeFunction_PushDownConstInput) {
   test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
 }
 
+TEST_F(FunctionOptimizerTest, SpecializeFunction_OncePerUniqueContext) {
+  using test::function::NDef;
+
+  FunctionOptimizer optimizer(RewriterConfig::DEFAULT);
+
+  // Mark MyMul as noinline.
+  FunctionDef mul_func = FunctionDefHelper::Create(
+      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, int32}"},
+      {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "output:z:0"}});
+  (*mul_func.mutable_attr())["_noinline"].set_b(true);
+  std::vector<FunctionDef> function_library = {mul_func};
+
+  const Tensor kTwo = test::AsScalar<float>(2.0f);
+  const Tensor kThree = test::AsScalar<float>(3.0f);
+
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {NDef("init", "NoOp", {}, {}, kDevice),
+
+       // Float placeholders.
+       NDef("xf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+       NDef("yf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+
+       // Int32 placeholders.
+       NDef("xi", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
+       NDef("yi", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
+
+       // Consts. Control inputs have to be attached to specialized func calls.
+       NDef("two", "Const", {"^init", "^xf"},
+            {{"dtype", DT_FLOAT}, {"value", kTwo}}, kDevice),
+       NDef("three", "Const", {"^init", "^xf"},
+            {{"dtype", DT_FLOAT}, {"value", kThree}}, kDevice),
+
+       // Specialization #1: DT_FLOAT type parameter.
+       NDef("mul_1", "MyMul", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
+       NDef("mul_2", "MyMul", {"yf", "xf"}, {{"T", DT_FLOAT}}, kDevice),
+
+       // Specialization #2: DT_INT32 type parameter.
+       NDef("mul_3", "MyMul", {"xi", "yi"}, {{"T", DT_INT32}}, kDevice),
+
+       // Specialization #3: DT_FLOAT type parameter + const input kTwo.
+       NDef("mul_4", "MyMul", {"xf", "two"}, {{"T", DT_FLOAT}}, kDevice),
+       NDef("mul_5", "MyMul", {"yf", "two"}, {{"T", DT_FLOAT}}, kDevice),
+
+       // Specialization #4: DT_FLOAT type parameter + const input kThree.
+       NDef("mul_6", "MyMul", {"three", "xf"}, {{"T", DT_FLOAT}}, kDevice)},
+      function_library);
+
+  GraphDef output;
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  // Make sure that MyMul was specialized once per unique context.
+  EXPECT_EQ(4, output.library().function_size());
+
+  // And graph nodes calling specialized functions.
+  int count = 0;
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "mul_1" && count++) {
+      EXPECT_EQ("MyMul_specialized_for_mul_1", node.op());
+      ASSERT_EQ(2, node.input_size());
+      EXPECT_EQ("xf", node.input(0));
+      EXPECT_EQ("yf", node.input(1));
+
+    } else if (node.name() == "mul_2" && count++) {
+      EXPECT_EQ("MyMul_specialized_for_mul_1", node.op());
+      ASSERT_EQ(2, node.input_size());
+      EXPECT_EQ("yf", node.input(0));
+      EXPECT_EQ("xf", node.input(1));
+
+    } else if (node.name() == "mul_3" && count++) {
+      EXPECT_EQ("MyMul_specialized_for_mul_3", node.op());
+      ASSERT_EQ(2, node.input_size());
+      EXPECT_EQ("xi", node.input(0));
+      EXPECT_EQ("yi", node.input(1));
+
+    } else if (node.name() == "mul_4" && count++) {
+      EXPECT_EQ("MyMul_specialized_for_mul_4", node.op());
+      ASSERT_EQ(2, node.input_size());
+      EXPECT_EQ("xf", node.input(0));
+      EXPECT_EQ("^init", node.input(1));
+
+    } else if (node.name() == "mul_5" && count++) {
+      EXPECT_EQ("MyMul_specialized_for_mul_4", node.op());
+      ASSERT_EQ(3, node.input_size());
+      EXPECT_EQ("yf", node.input(0));
+      EXPECT_EQ("^init", node.input(1));
+      EXPECT_EQ("^xf", node.input(2));
+
+    } else if (node.name() == "mul_6" && count++) {
+      EXPECT_EQ("MyMul_specialized_for_mul_6", node.op());
+      ASSERT_EQ(2, node.input_size());
+      EXPECT_EQ("xf", node.input(0));
+      EXPECT_EQ("^init", node.input(1));
+    }
+  }
+  EXPECT_EQ(6, count);
+
+  // And that graph evaluation yields the same result.
+  Tensor pi = test::AsScalar<float>(3.14f);
+  Tensor four = test::AsScalar<int32>(4);
+  item.fetch = {"mul_1", "mul_2", "mul_3", "mul_4", "mul_5", "mul_6"};
+  item.feed = {{"xf", pi}, {"yf", pi}, {"xi", four}, {"yi", four}};
+
+  auto tensors_expected = EvaluateFetchNodes(item);
+  GrapplerItem optimized(item, std::move(output));
+  auto tensors = EvaluateFetchNodes(optimized);
+
+  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+  test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
+  test::ExpectTensorEqual<int32>(tensors_expected[2], tensors[2]);
+  test::ExpectTensorEqual<float>(tensors_expected[3], tensors[3]);
+  test::ExpectTensorEqual<float>(tensors_expected[4], tensors[4]);
+  test::ExpectTensorEqual<float>(tensors_expected[5], tensors[5]);
+}
+
 }  // namespace grappler
 }  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 887a988..8247cce 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -163,30 +163,28 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
                                              output.library());
 
   // Specialized and optimized functions should be added to the graph.
-  EXPECT_EQ(6, optimized_flib.num_functions());
+  EXPECT_EQ(5, optimized_flib.num_functions());
 
   // MyQuadratic should be specialized once:
   //   0. 'quadratic' node in the main graph
   const string optimized_0 = "MyQuadratic_specialized_for_quadratic";
 
   // MySquare should be specialized and optimized for 3 instantiations:
-  //   1. 'square' node in the main graph
-  //   2. 'square' node in the MyQuadratic specialization
-  //   3. 'quadratic' node in the MyQuadratic specialization
+  //   1. 'square' node in the main graph
+  //   2. 'square' node in the MyQuadratic specialization
+  //   3*. 'quadratic' node in the MyQuadratic specialization
+  //       has identical instantiation context to #2
   const string optimized_1 = "MySquare_specialized_for_square";
   const string optimized_2 = "MySquare_specialized_for_square_1";
-  const string optimized_3 = "MySquare_specialized_for_quadratic";
 
   const FunctionDef* optimized_func_0 = optimized_flib.Find(optimized_0);
   const FunctionDef* optimized_func_1 = optimized_flib.Find(optimized_1);
   const FunctionDef* optimized_func_2 = optimized_flib.Find(optimized_2);
-  const FunctionDef* optimized_func_3 = optimized_flib.Find(optimized_3);
 
   ASSERT_NE(optimized_func_0, nullptr);
   ASSERT_NE(optimized_func_1, nullptr);
   ASSERT_NE(optimized_func_2, nullptr);
-  ASSERT_NE(optimized_func_3, nullptr);
 
   // Graph should call optimized function.
   int count = 0;
@@ -205,13 +203,14 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
     if (node.name() == "square" && count++) {
       EXPECT_EQ(optimized_2, node.op());
     } else if (node.name() == "quadratic" && count++) {
-      EXPECT_EQ(optimized_3, node.op());
+      // Share specialized function with the 'square' node.
+      EXPECT_EQ(optimized_2, node.op());
     }
   }
   EXPECT_EQ(2, count);
 
-  const std::vector<const FunctionDef*> optimized_funcs = {
-      optimized_func_1, optimized_func_1, optimized_func_3};
+  const std::vector<const FunctionDef*> optimized_funcs = {optimized_func_1,
+                                                           optimized_func_2};
 
   // MyMul should be inlined into all optimized versions of MySquare.
   for (const FunctionDef* optimized_func : optimized_funcs) {
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index 79b823f..34603f9 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -417,6 +417,63 @@ bool IsParametrized(const FunctionDef& func) {
   return HasParametrizedType(func) || HasParametrizedBody(func);
 }
 
+Status InstantiationTypeParameters(
+    const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+    std::unordered_map<string, DataType>* type_parameters) {
+  if (!type_parameters->empty()) {
+    return errors::InvalidArgument("Type parameters output map must be empty");
+  }
+
+  GrapplerFunctionItemInstantiation instantiation(&func_instantiation_attr);
+
+  const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) {
+    // Check if the arg type is unknown and not yet resolved.
+    if (arg.type() == DT_INVALID &&
+        type_parameters->find(arg.type_attr()) == type_parameters->end()) {
+      DataType data_type;
+      TF_RETURN_IF_ERROR(instantiation.GetArgType(arg, &data_type));
+      type_parameters->insert({arg.type_attr(), data_type});
+    }
+    return Status::OK();
+  };
+
+  for (const auto& input : func.signature().input_arg())
+    TF_RETURN_IF_ERROR(resolve_type_attr(input));
+  for (const auto& output : func.signature().output_arg())
+    TF_RETURN_IF_ERROR(resolve_type_attr(output));
+
+  return Status::OK();
+}
+
+Status InstantiationBodyParameters(
+    const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+    std::unordered_map<string, AttrValue>* body_parameters) {
+  if (!body_parameters->empty()) {
+    return errors::InvalidArgument("Body parameters output map must be empty");
+  }
+
+  for (const NodeDef& func_body_node : func.node_def()) {
+    for (auto& attr : func_body_node.attr()) {
+      const string& placeholder = attr.second.placeholder();
+
+      if (placeholder.empty() ||
+          body_parameters->find(placeholder) != body_parameters->end()) {
+        continue;
+      }
+
+      auto it = func_instantiation_attr.find(placeholder);
+      if (it != func_instantiation_attr.end()) {
+        body_parameters->emplace(placeholder, it->second);
+      } else {
+        return errors::InvalidArgument("Can't resolve placeholder: ",
+                                       placeholder);
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
 Status MakeGrapplerFunctionItem(const FunctionDef& func,
                                 const AttrValueMap& func_instantiation_attr,
                                 const FunctionLibraryDefinition& flib,
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index d9d71b8..4641bf5 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -191,6 +191,19 @@ bool HasParametrizedBody(const FunctionDef& func);
 // Check if function has parametrized type or body.
 bool IsParametrized(const FunctionDef& func);
 
+// Resolve function instantiation type parameters from the attributes of the
+// caller node. Returns an error if a type parameter can't be resolved.
+Status InstantiationTypeParameters(
+    const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+    std::unordered_map<string, DataType>* type_parameters);
+
+// Resolve function instantiation body parameters (values for the function body
+// attr placeholders) from the attributes of the caller node. Returns an error
+// if a placeholder can't be resolved.
+Status InstantiationBodyParameters(
+    const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+    std::unordered_map<string, AttrValue>* body_parameters);
+
 // Register GrapplerFunctionItem input arg expansion and function body outputs
 // in the GrapplerFunctionConnectivity. Use function library definition to
 // lookup function body nodes output names and ranges.
@@ -205,10 +218,10 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
 // Make a GrapplerFunctionItem from the function definition and function
 // instantiation attributes (caller node attributes). Returns error if the
 // given function def cannot be converted (e.g. not all attributes are defined).
-Status MakeGrapplerFunctionItem(
-    const FunctionDef& func,
-    const std::unordered_map<string, AttrValue>& func_instantiation_attr,
-    const FunctionLibraryDefinition& flib, GrapplerFunctionItem* item);
+Status MakeGrapplerFunctionItem(const FunctionDef& func,
+                                const AttrValueMap& func_instantiation_attr,
+                                const FunctionLibraryDefinition& flib,
+                                GrapplerFunctionItem* item);
 
 // Make a GrapplerFunction item from the function definition. Function must be
 // fully defined (no type or body parametrization).
diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc
index fa6fec7..15d8437 100644
--- a/tensorflow/core/grappler/utils/functions_test.cc
+++ b/tensorflow/core/grappler/utils/functions_test.cc
@@ -54,6 +54,44 @@ TEST_F(FunctionsTest, IsParametrized) {
   EXPECT_FALSE(IsParametrized(non_parametrized_func));
 }
 
+TEST_F(FunctionsTest, InstantiationParameters) {
+  // Function definition is invalid; only type/body parameters are important.
+  FunctionDef func = FunctionDefHelper::Create(
+      "ParametrizedFunc",
+      /* inputs */
+      {"input1:A", "input2:B", "input3:float"},
+      /* outputs */
+      {"output1: A", "output2:C"},
+      /* type parameters */
+      {"A: {float, double}", "B: {float, int32}", "C: {float, double}"},
+      /* function body */
+      {{{"output"}, "FakeOp", {"input1", "input2"}, {{"key", "$key"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"x", "cx:output:0"}, {"y", "cy:output:0"}});
+
+  std::unordered_map<string, AttrValue> func_instantiation_attr;
+  func_instantiation_attr["key"].set_s("key-value");
+  func_instantiation_attr["A"].set_type(DT_FLOAT);
+  func_instantiation_attr["B"].set_type(DT_INT32);
+  func_instantiation_attr["C"].set_type(DT_DOUBLE);
+
+  std::unordered_map<string, DataType> type_parameters;
+  TF_EXPECT_OK(InstantiationTypeParameters(func, func_instantiation_attr,
+                                           &type_parameters));
+
+  ASSERT_EQ(3, type_parameters.size());
+  EXPECT_EQ(DT_FLOAT, type_parameters["A"]);
+  EXPECT_EQ(DT_INT32, type_parameters["B"]);
+  EXPECT_EQ(DT_DOUBLE, type_parameters["C"]);
+
+  std::unordered_map<string, AttrValue> body_parameters;
+  TF_EXPECT_OK(InstantiationBodyParameters(func, func_instantiation_attr,
+                                           &body_parameters));
+
+  ASSERT_EQ(1, body_parameters.size());
+  EXPECT_EQ("key-value", body_parameters["key"].s());
+}
+
 TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandFunctionDefInput) {
   GrapplerFunctionConnectivity connectivity;
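A note on the hashing in FunctionSpecializationSignature::Hash above: the signature's parameter maps are unordered containers, so they are first copied into std::map and hashed in sorted key order; that way two equal signatures always combine to the same hash regardless of bucket layout. The fragment below is a standalone sketch of that trick with plain strings standing in for AttrValue; it is illustrative only, not TensorFlow code.

    #include <cstddef>
    #include <functional>
    #include <map>
    #include <string>
    #include <unordered_map>

    // Order-sensitive hash combiner (Boost-style); any combiner works as long
    // as the entries are fed to it in a deterministic order.
    std::size_t CombineHash(std::size_t seed, std::size_t value) {
      return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
    }

    // Hash an unordered attribute map by iterating a sorted copy, so that
    // equal maps always produce the same hash.
    std::size_t HashAttrMap(
        const std::unordered_map<std::string, std::string>& attrs) {
      std::map<std::string, std::string> sorted(attrs.begin(), attrs.end());
      std::size_t h = 0;
      for (const auto& kv : sorted) {
        h = CombineHash(h, std::hash<std::string>()(kv.first));
        h = CombineHash(h, std::hash<std::string>()(kv.second));
      }
      return h;
    }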