Specialize functions only once per unique context.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 7 May 2018 20:18:33 +0000 (13:18 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 8 May 2018 00:03:53 +0000 (17:03 -0700)
PiperOrigin-RevId: 195710562

tensorflow/core/grappler/optimizers/function_optimizer.cc
tensorflow/core/grappler/optimizers/function_optimizer_test.cc
tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
tensorflow/core/grappler/utils/functions.cc
tensorflow/core/grappler/utils/functions.h
tensorflow/core/grappler/utils/functions_test.cc

index 1bec908..a44e1ee 100644 (file)
@@ -14,10 +14,13 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
+
 #include <unordered_map>
+
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/graph_def_util.h"
@@ -74,6 +77,73 @@ string UniqueSpecializedFunctionName(const FunctionDef& func,
   return unique_name;
 }
 
+// Specialized function instantiation type parameters, body parameters, and
+// const inputs.
+struct FunctionSpecializationSignature {
+  string func_name;
+  std::unordered_map<string, DataType> type_parameters;
+  std::unordered_map<string, AttrValue> body_parameters;
+  std::unordered_map<int, string> const_inputs;
+
+  bool operator==(const FunctionSpecializationSignature& other) const {
+    bool equals = func_name == other.func_name &&
+                  type_parameters == other.type_parameters &&
+                  const_inputs == other.const_inputs;
+
+    if (!equals) return false;
+
+    // Equality is not defined for AttrValue.
+    if (body_parameters.size() != other.body_parameters.size()) return false;
+
+    for (const auto& lhs : body_parameters) {
+      auto it = other.body_parameters.find(lhs.first);
+      if (it == other.body_parameters.end()) return false;
+      if (!AreAttrValuesEqual(lhs.second, (*it).second)) return false;
+    }
+
+    return true;
+  }
+
+  struct Hash {
+    uint64 operator()(FunctionSpecializationSignature const& s) const {
+      uint64 h = Hash64(s.func_name);
+
+      // Use std::map for deterministic iteration order.
+
+      std::map<string, DataType> types(s.type_parameters.begin(),
+                                       s.type_parameters.end());
+      for (const auto& pair : types) {
+        AttrValue attr_value;
+        attr_value.set_type(pair.second);
+        h = Hash64Combine(Hash64(pair.first), h);
+        h = Hash64Combine(AttrValueHash(attr_value), h);
+      }
+
+      std::map<string, AttrValue> body(s.body_parameters.begin(),
+                                       s.body_parameters.end());
+      for (const auto& pair : body) {
+        h = Hash64Combine(Hash64(pair.first), h);
+        h = Hash64Combine(AttrValueHash(pair.second), h);
+      }
+
+      std::map<int, string> inputs(s.const_inputs.begin(),
+                                   s.const_inputs.end());
+      for (const auto& pair : inputs) {
+        h = Hash64Combine(std::hash<int>()(pair.first), h);
+        h = Hash64Combine(Hash64(pair.second), h);
+      }
+
+      return h;
+    }
+  };
+};
+
+struct FunctionSpecialization {
+  string specialized_func_name;
+  std::unordered_set<string> const_inputs;
+  std::unordered_set<string> control_deps;
+};
+
 class FunctionOptimizerContext {
  public:
   explicit FunctionOptimizerContext(RewriterConfig::Toggle opt_level,
@@ -108,6 +178,16 @@ class FunctionOptimizerContext {
     return gtl::FindWithDefault(inlined_functions_, name, nullptr);
   }
 
+  const FunctionSpecialization* FindFunctionSpecialization(
+      const FunctionSpecializationSignature& sig) const {
+    return gtl::FindOrNull(specialized_functions_, sig);
+  }
+
+  void AddSpecializedFunction(const FunctionSpecializationSignature& sig,
+                              const FunctionSpecialization& specialized_func) {
+    specialized_functions_.emplace(sig, specialized_func);
+  }
+
  private:
   void InitializeTrulyConstNodes(const GrapplerItem& item) {
     std::unordered_set<string> feed_nodes;
@@ -148,6 +228,12 @@ class FunctionOptimizerContext {
   // Nodes that are Const and not in feed.
   std::unordered_map<string, const NodeDef*> truly_const_nodes_;
 
+  // Specialized functions.
+  std::unordered_map<FunctionSpecializationSignature,
+                     const FunctionSpecialization,
+                     FunctionSpecializationSignature::Hash>
+      specialized_functions_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext);
 };
 
@@ -303,14 +389,34 @@ void RemovePushedDownConstInputs(const std::unordered_set<string>& const_inputs,
 
     for (const string& ctrl : control_deps) {
       if (existing_control_deps.find(ctrl) == existing_control_deps.end()) {
-        VLOG(3) << "Forward control dependency to function caller node: input="
-                << ctrl;
+        VLOG(3) << "Forward control dependency: input=" << ctrl;
         specialized_func_node->add_input(ctrl);
       }
     }
   }
 }
 
+Status InitializeFunctionSpecializationSignature(
+    const NodeDef& func_node, const FunctionDef& func,
+    const AttrValueMap& func_attr, const FunctionOptimizerContext& ctx,
+    FunctionSpecializationSignature* sig) {
+  sig->func_name = func.signature().name();
+
+  TF_RETURN_IF_ERROR(
+      InstantiationTypeParameters(func, func_attr, &sig->type_parameters));
+  TF_RETURN_IF_ERROR(
+      InstantiationBodyParameters(func, func_attr, &sig->body_parameters));
+
+  for (int i = 0; i < func_node.input_size(); ++i) {
+    const string& input = func_node.input(i);
+    if (ctx.IsTrulyConst(input)) {
+      sig->const_inputs.emplace(i, input);
+    }
+  }
+
+  return Status::OK();
+}
+
 Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
                           FunctionOptimizerContext* ctx,
                           GraphDef* optimized_graph) {
@@ -320,6 +426,32 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
   const std::unordered_map<string, AttrValue> func_attr(
       func_node.attr().begin(), func_node.attr().end());
 
+  FunctionSpecializationSignature signature;
+  TF_RETURN_IF_ERROR(InitializeFunctionSpecializationSignature(
+      func_node, func, func_attr, *ctx, &signature));
+
+  // Check if function was already specialized for identical context.
+  const FunctionSpecialization* already_specialized =
+      ctx->FindFunctionSpecialization(signature);
+
+  if (already_specialized) {
+    VLOG(2) << "Function was already specialized in identical context: "
+               "specialized_name="
+            << already_specialized->specialized_func_name;
+
+    // Add a function call node for the specialized function.
+    NodeDef* specialized_func_node = optimized_graph->add_node();
+    *specialized_func_node = func_node;
+    specialized_func_node->set_op(already_specialized->specialized_func_name);
+
+    RemovePushedDownConstInputs(already_specialized->const_inputs,
+                                already_specialized->control_deps,
+                                specialized_func_node);
+
+    return Status::OK();
+  }
+
+  // Add a new specialized function definition to the library.
   const auto& flib = ctx->function_library();
 
   // Make a GrapplerFunctionItem and convert it back to FunctionDef after
@@ -358,6 +490,10 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
   // Update specialized node to remove inputs for pushed down consts.
   RemovePushedDownConstInputs(const_inputs, control_deps,
                               specialized_func_node);
+
+  ctx->AddSpecializedFunction(
+      signature, {specialized_func_name, const_inputs, control_deps});
+
   return Status::OK();
 }
 
index 147a264..a2dbab3 100644 (file)
@@ -718,5 +718,122 @@ TEST_F(FunctionOptimizerTest, SpecializeFunction_PushDownConstInput) {
   test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
 }
 
+TEST_F(FunctionOptimizerTest, SpecializeFunction_OncePerUniqueContext) {
+  using test::function::NDef;
+
+  FunctionOptimizer optimizer(RewriterConfig::DEFAULT);
+
+  // Mark MyMul as noinline.
+  FunctionDef mul_func = FunctionDefHelper::Create(
+      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, int32}"},
+      {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "output:z:0"}});
+  (*mul_func.mutable_attr())["_noinline"].set_b(true);
+  std::vector<FunctionDef> function_library = {mul_func};
+
+  const Tensor kTwo = test::AsScalar<float>(2.0);
+  const Tensor kThree = test::AsScalar<float>(3.0);
+
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {NDef("init", "NoOp", {}, {}, kDevice),
+
+       // Float placeholders.
+       NDef("xf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+       NDef("yf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+
+       // Int32 placeholders.
+       NDef("xi", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
+       NDef("yi", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
+
+       // Consts. Control inputs has to be attached to specialized func calls.
+       NDef("two", "Const", {"^init", "^xf"},
+            {{"dtype", DT_FLOAT}, {"value", kTwo}}, kDevice),
+       NDef("three", "Const", {"^init", "^xf"},
+            {{"dtype", DT_FLOAT}, {"value", kThree}}, kDevice),
+
+       // Specialization #1: DT_FLOAT type parameter.
+       NDef("mul_1", "MyMul", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
+       NDef("mul_2", "MyMul", {"yf", "xf"}, {{"T", DT_FLOAT}}, kDevice),
+
+       // Specialization #2: DT_INT32 type parameter.
+       NDef("mul_3", "MyMul", {"xi", "yi"}, {{"T", DT_INT32}}, kDevice),
+
+       // Specialization #3: DT_FLOAT type parameter + const input kTwo.
+       NDef("mul_4", "MyMul", {"xf", "two"}, {{"T", DT_FLOAT}}, kDevice),
+       NDef("mul_5", "MyMul", {"yf", "two"}, {{"T", DT_FLOAT}}, kDevice),
+
+       // Specialization #4: DT_FLOAT type parameter + const input kThree.
+       NDef("mul_6", "MyMul", {"three", "xf"}, {{"T", DT_FLOAT}}, kDevice)},
+      function_library);
+
+  GraphDef output;
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  // Make sure that MyMul was specialized once per unique context.
+  EXPECT_EQ(4, output.library().function_size());
+
+  // And graph nodes calling specialized functions.
+  int count = 0;
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "mul_1" && count++) {
+      EXPECT_EQ("MyMul_specialized_for_mul_1", node.op());
+      ASSERT_EQ(2, node.input_size());
+      EXPECT_EQ("xf", node.input(0));
+      EXPECT_EQ("yf", node.input(1));
+
+    } else if (node.name() == "mul_2" && count++) {
+      EXPECT_EQ("MyMul_specialized_for_mul_1", node.op());
+      ASSERT_EQ(2, node.input_size());
+      EXPECT_EQ("yf", node.input(0));
+      EXPECT_EQ("xf", node.input(1));
+
+    } else if (node.name() == "mul_3" && count++) {
+      EXPECT_EQ("MyMul_specialized_for_mul_3", node.op());
+      ASSERT_EQ(2, node.input_size());
+      EXPECT_EQ("xi", node.input(0));
+      EXPECT_EQ("yi", node.input(1));
+
+    } else if (node.name() == "mul_4" && count++) {
+      EXPECT_EQ("MyMul_specialized_for_mul_4", node.op());
+      ASSERT_EQ(2, node.input_size());
+      EXPECT_EQ("xf", node.input(0));
+      EXPECT_EQ("^init", node.input(1));
+
+    } else if (node.name() == "mul_5" && count++) {
+      EXPECT_EQ("MyMul_specialized_for_mul_4", node.op());
+      ASSERT_EQ(3, node.input_size());
+      EXPECT_EQ("yf", node.input(0));
+      EXPECT_EQ("^init", node.input(1));
+      EXPECT_EQ("^xf", node.input(2));
+
+    } else if (node.name() == "mul_6" && count++) {
+      EXPECT_EQ("MyMul_specialized_for_mul_6", node.op());
+      ASSERT_EQ(2, node.input_size());
+      EXPECT_EQ("xf", node.input(0));
+      EXPECT_EQ("^init", node.input(1));
+    }
+  }
+  EXPECT_EQ(6, count);
+
+  // And that graph evaluation yields the same result.
+  Tensor pi = test::AsScalar<float>(3.14f);
+  Tensor four = test::AsScalar<int32>(4);
+  item.fetch = {"mul_1", "mul_2", "mul_3", "mul_4", "mul_5", "mul_6"};
+  item.feed = {{"xf", pi}, {"yf", pi}, {"xi", four}, {"yi", four}};
+
+  auto tensors_expected = EvaluateFetchNodes(item);
+  GrapplerItem optimized(item, std::move(output));
+  auto tensors = EvaluateFetchNodes(optimized);
+
+  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+  test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
+  test::ExpectTensorEqual<int32>(tensors_expected[2], tensors[2]);
+  test::ExpectTensorEqual<float>(tensors_expected[3], tensors[3]);
+  test::ExpectTensorEqual<float>(tensors_expected[4], tensors[4]);
+  test::ExpectTensorEqual<float>(tensors_expected[5], tensors[5]);
+}
+
 }  // namespace grappler
 }  // namespace tensorflow
index 887a988..8247cce 100644 (file)
@@ -163,30 +163,28 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
                                            output.library());
 
   // Specialized and optimized functions should be added to the graph.
-  EXPECT_EQ(6, optimized_flib.num_functions());
+  EXPECT_EQ(5, optimized_flib.num_functions());
 
   // MyQuadratic should be specialized once:
   //   0. 'quadratic' node in the main graph
   const string optimized_0 = "MyQuadratic_specialized_for_quadratic";
 
   // MySquare should be specialized and optimized for 3 instantiations:
-  //   1. 'square' node in the main graph
-  //   2. 'square' node in the MyQuadratic specialization
-  //   3. 'quadratic' node in the MyQuadratic specialization
+  //   1.  'square' node in the main graph
+  //   2.  'square' node in the MyQuadratic specialization
+  //   3*. 'quadratic' node in the MyQuadratic specialization
+  //        has identical instantiation context to #2
 
   const string optimized_1 = "MySquare_specialized_for_square";
   const string optimized_2 = "MySquare_specialized_for_square_1";
-  const string optimized_3 = "MySquare_specialized_for_quadratic";
 
   const FunctionDef* optimized_func_0 = optimized_flib.Find(optimized_0);
   const FunctionDef* optimized_func_1 = optimized_flib.Find(optimized_1);
   const FunctionDef* optimized_func_2 = optimized_flib.Find(optimized_2);
-  const FunctionDef* optimized_func_3 = optimized_flib.Find(optimized_3);
 
   ASSERT_NE(optimized_func_0, nullptr);
   ASSERT_NE(optimized_func_1, nullptr);
   ASSERT_NE(optimized_func_2, nullptr);
-  ASSERT_NE(optimized_func_3, nullptr);
 
   // Graph should call optimized function.
   int count = 0;
@@ -205,13 +203,14 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
     if (node.name() == "square" && count++) {
       EXPECT_EQ(optimized_2, node.op());
     } else if (node.name() == "quadratic" && count++) {
-      EXPECT_EQ(optimized_3, node.op());
+      // Share specialized function with the 'square' node.
+      EXPECT_EQ(optimized_2, node.op());
     }
   }
   EXPECT_EQ(2, count);
 
-  const std::vector<const FunctionDef*> optimized_funcs = {
-      optimized_func_1, optimized_func_1, optimized_func_3};
+  const std::vector<const FunctionDef*> optimized_funcs = {optimized_func_1,
+                                                           optimized_func_2};
 
   // MyMul should be inlined into all optimized versions of MySquare.
   for (const FunctionDef* optimized_func : optimized_funcs) {
index 79b823f..34603f9 100644 (file)
@@ -417,6 +417,63 @@ bool IsParametrized(const FunctionDef& func) {
   return HasParametrizedType(func) || HasParametrizedBody(func);
 }
 
+Status InstantiationTypeParameters(
+    const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+    std::unordered_map<string, DataType>* type_parameters) {
+  if (!type_parameters->empty()) {
+    return errors::InvalidArgument("Type parameters output map must be empty");
+  }
+
+  GrapplerFunctionItemInstantiation instantiation(&func_instantiation_attr);
+
+  const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) {
+    // Check if it's unknown and unresolved type.
+    if (arg.type() == DT_INVALID &&
+        type_parameters->find(arg.type_attr()) == type_parameters->end()) {
+      DataType data_type;
+      TF_RETURN_IF_ERROR(instantiation.GetArgType(arg, &data_type));
+      type_parameters->insert({arg.type_attr(), data_type});
+    }
+    return Status::OK();
+  };
+
+  for (const auto& input : func.signature().input_arg())
+    TF_RETURN_IF_ERROR(resolve_type_attr(input));
+  for (const auto& output : func.signature().output_arg())
+    TF_RETURN_IF_ERROR(resolve_type_attr(output));
+
+  return Status::OK();
+}
+
+Status InstantiationBodyParameters(
+    const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+    std::unordered_map<string, AttrValue>* body_parameters) {
+  if (!body_parameters->empty()) {
+    return errors::InvalidArgument("Body parameters output map must be empty");
+  }
+
+  for (const NodeDef& func_body_node : func.node_def()) {
+    for (auto& attr : func_body_node.attr()) {
+      const string& placeholder = attr.second.placeholder();
+
+      if (placeholder.empty() ||
+          body_parameters->find(placeholder) != body_parameters->end()) {
+        continue;
+      }
+
+      auto it = func_instantiation_attr.find(placeholder);
+      if (it != func_instantiation_attr.end()) {
+        body_parameters->emplace(placeholder, it->second);
+      } else {
+        return errors::InvalidArgument("Can't resolve placeholder: ",
+                                       placeholder);
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
 Status MakeGrapplerFunctionItem(const FunctionDef& func,
                                 const AttrValueMap& func_instantiation_attr,
                                 const FunctionLibraryDefinition& flib,
index d9d71b8..4641bf5 100644 (file)
@@ -191,6 +191,19 @@ bool HasParametrizedBody(const FunctionDef& func);
 // Check if function has parametrized type or body.
 bool IsParametrized(const FunctionDef& func);
 
+// Resolve function instantiation type parameters from the attributes of the
+// caller node. Return error if type can't be resolved.
+Status InstantiationTypeParameters(
+    const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+    std::unordered_map<string, DataType>* type_parameters);
+
+// Resolve function instantiation body parameters (values for the function body
+// attr placeholders) from the attributes of the caller node. Return error if
+// type can't be resolved.
+Status InstantiationBodyParameters(
+    const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+    std::unordered_map<string, AttrValue>* body_parameters);
+
 // Register GrapplerFunctionItem input arg expansion and function body outputs
 // in the GrapplerFunctionConnectivity. Use function library definition to
 // lookup function body nodes output names and ranges.
@@ -205,10 +218,10 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
 // Make a GrapplerFunctionItem from the function definition and function
 // instantiation attributes (caller node attributes). Returns error if the given
 // function def cannot be converted (e.g. not all attributes are defined).
-Status MakeGrapplerFunctionItem(
-    const FunctionDef& func,
-    const std::unordered_map<string, AttrValue>& func_instantiation_attr,
-    const FunctionLibraryDefinition& flib, GrapplerFunctionItem* item);
+Status MakeGrapplerFunctionItem(const FunctionDef& func,
+                                const AttrValueMap& func_instantiation_attr,
+                                const FunctionLibraryDefinition& flib,
+                                GrapplerFunctionItem* item);
 
 // Make a GrapplerFunction item from the function definition. Function must be
 // fully defined (no type or body parametrization).
index fa6fec7..15d8437 100644 (file)
@@ -54,6 +54,44 @@ TEST_F(FunctionsTest, IsParametrized) {
   EXPECT_FALSE(IsParametrized(non_parametrized_func));
 }
 
+TEST_F(FunctionsTest, InstantiationParameters) {
+  // Function definition is invalid, only type/body parameters are important.
+  FunctionDef func = FunctionDefHelper::Create(
+      "ParametrizedFunc",
+      /* inputs */
+      {"input1:A", "input2:B", "input3:float"},
+      /* outputs */
+      {"output1: A", "output2:C"},
+      /* type parameters */
+      {"A: {float, double}", "B: {float, int32}", "C: {float, double}"},
+      /* function body*/
+      {{{"output"}, "FakeOp", {"input1", "input2"}, {{"key", "$key"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"x", "cx:output:0"}, {"y", "cy:output:0"}});
+
+  std::unordered_map<string, AttrValue> func_instantiation_attr;
+  func_instantiation_attr["key"].set_s("key-value");
+  func_instantiation_attr["A"].set_type(DT_FLOAT);
+  func_instantiation_attr["B"].set_type(DT_INT32);
+  func_instantiation_attr["C"].set_type(DT_DOUBLE);
+
+  std::unordered_map<string, DataType> type_parameters;
+  TF_EXPECT_OK(InstantiationTypeParameters(func, func_instantiation_attr,
+                                           &type_parameters));
+
+  ASSERT_EQ(3, type_parameters.size());
+  EXPECT_EQ(DT_FLOAT, type_parameters["A"]);
+  EXPECT_EQ(DT_INT32, type_parameters["B"]);
+  EXPECT_EQ(DT_DOUBLE, type_parameters["C"]);
+
+  std::unordered_map<string, AttrValue> body_parameters;
+  TF_EXPECT_OK(InstantiationBodyParameters(func, func_instantiation_attr,
+                                           &body_parameters));
+
+  ASSERT_EQ(1, body_parameters.size());
+  EXPECT_EQ("key-value", body_parameters["key"].s());
+}
+
 TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandFunctionDefInput) {
   GrapplerFunctionConnectivity connectivity;