[Intel MKL] Fix for convrnn unit test failure (#19229)
authorNiranjan Hasabnis <niranjan.hasabnis@intel.com>
Wed, 16 May 2018 16:58:05 +0000 (09:58 -0700)
committerRasmus Munk Larsen <rmlarsen@google.com>
Wed, 16 May 2018 16:58:05 +0000 (09:58 -0700)
Adding a fixup pass (+unit test) to handle incorrectly linked Mkl metadata
edges

This graph pass is needed because a graph may have some input Mkl metadata
edges incorrectly setup after node merge and rewrite passes. This could
happen because GetReversePostOrder function may not provide a topologically
sorted order if a graph contains cycles.

Minor style issue left over from earlier commits to the graph pass are also
handled (but not related to the fixup pass).

tensorflow/core/graph/mkl_layout_pass.cc
tensorflow/core/graph/mkl_layout_pass_test.cc

index 72a13d4..b966799 100644 (file)
@@ -2691,14 +2691,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
 
     // If Op has been specifically assigned to a non-CPU device, then No.
     if (!n->assigned_device_name().empty() &&
-        !str_util::StrContains(n->assigned_device_name(),kCPUDeviceSubStr)) {
+        !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) {
       result = false;
       reason = "Op has been assigned a runtime device that is not CPU.";
     }
 
     // If user has specifically assigned this op to a non-CPU device, then No.
     if (!n->def().device().empty() &&
-        !str_util::StrContains(n->def().device(),kCPUDeviceSubStr)) {
+        !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) {
       result = false;
       reason = "User has assigned a device that is not CPU.";
     }
@@ -2865,9 +2865,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     return false;
   }
 
-  // If the depth_radius of LRN is not 2, then MKL DNN takes unoptimized 
-  // path. The unoptimized path is slow. Thus we dont rewrite the node 
-  // and use default Eigen. But for depth_radius=2, MKL DNN optimized 
+  // If the depth_radius of LRN is not 2, then MKL DNN takes unoptimized
+  // path. The unoptimized path is slow. Thus we dont rewrite the node
+  // and use default Eigen. But for depth_radius=2, MKL DNN optimized
   // path is taken, i.e., eigen node is rewritten by MKl DNN node.
   static bool LrnRewrite(const Node* n) {
     CHECK_NOTNULL(n);
@@ -2876,13 +2876,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     CHECK_EQ(GetNodeAttr(n->def(), "depth_radius", &depth_radius).ok(), true);
 
     // if the depth_radius of LRN is not 2, don't rewrite the node by MKL DNN
-    // and use eigen node instead 
+    // and use eigen node instead
     if (depth_radius == 2) {
       return true;
     }
     VLOG(1) << "LrnRewrite: The model sets depth_radius as not 2 which"
             << "case is not optimized by Intel MKL, thus using Eigen op"
-            << "for LRN " ; 
+            << "for LRN ";
 
     return false;
   }
@@ -3015,6 +3015,35 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                                 std::vector<NodeBuilder::NodeOut>* ws_tensors,
                                 bool* are_ws_tensors_added);
 
+  // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge
+  // pointed by 'e_metadata' corresponding to the data edge 'e_data' in graph
+  // 'g'. Returns true is fixup was done; otherwise, it returns false.
+  bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
+    const Edge* e_data, const Edge* e_metadata);
+
+  // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly
+  // connected? If not, then fix them. This is needed because a graph may have
+  // some input Mkl metadata edges incorrectly setup after node merge and
+  // rewrite passes. This could happen because GetReversePostOrder function may
+  // not provide topologically sorted order if a graph contains cycles. The
+  // function returns true if at least one Mkl metadata edge for node 'n' was
+  // fixed. Otherwise, it returns false.
+  //
+  // Example:
+  //
+  // X = MklConv2D(_, _, _)
+  // Y = MklConv2DWithBias(_, _, _, _, _, _)
+  // Z = MklAdd(X, Y, DummyMklTensor, Y:1)
+  //
+  // For a graph such as shown above, note that 3rd argument of MklAdd contains
+  // DummyMklTensor. Actually, it should be getting the Mkl metadata from
+  // MklConv2D op (specifically, X:2). This incorrect plumbing could be possible
+  // (although rare) if the Mkl NodeMerge + NodeRewrite passes visit Z before X
+  // (possible if X, Y, Z are part of a loop.) This function fixes the Mkl
+  // metadata edges only - it does not rewrite nodes nor does it modify the Mkl
+  // data edges (1st and 2nd arguments of MklAdd).
+  bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n);
+
   // Functions specific to operators to copy attributes
   // We need operator-specific function to copy attributes because the framework
   // does not provide any generic function for it.
@@ -4242,6 +4271,92 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
 }
 
 ///////////////////////////////////////////////////////////////////////////////
+//              Post-rewrite Mkl metadata fixup pass
+///////////////////////////////////////////////////////////////////////////////
+bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
+    const Edge* e_data, const Edge* e_metadata) {
+  if (g == nullptr || e_data == nullptr || e_metadata == nullptr) {
+    return false;
+  }
+
+  Node* n_data = e_data->src();
+  int n_data_op_slot = e_data->src_output();
+  int n_metadata_op_slot = GetTensorMetaDataIndex(n_data_op_slot,
+                                                  n_data->num_outputs());
+
+  // If the source of meta edge is a constant node (producing dummy Mkl metadata
+  // tensor), then we will need to fix.
+  if (IsConstant(e_metadata->src())) {
+    Node* e_metadata_dst = e_metadata->dst();
+    int e_metadata_in_slot = e_metadata->dst_input();
+    CHECK_NOTNULL((*g)->AddEdge(n_data, n_metadata_op_slot,
+                  e_metadata_dst, e_metadata_in_slot));
+
+    (*g)->RemoveEdge(e_metadata);
+    return true;
+  }
+
+  return false;
+}
+
+bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
+    Node* n) {
+  bool result = false;
+
+  // If graph node is not Mkl node, then return.
+  DataType T = DT_INVALID;
+  if (!GetNodeAttr(n->def(), "T", &T).ok() ||
+      !mkl_op_registry::IsMklOp(n->type_string(), T)) {
+    return result;
+  }
+
+  // If it is Mkl node, then check if the input edges to this node that carry
+  // Mkl metadata are linked up correctly with the source node.
+
+  // For Mkl nodes, we generate twice the number of input tensors (n for Mkl
+  // data tensors + n for Mkl metadata tensors). We need to check for correct
+  // connection of n metadata tensors only.
+  int num_data_inputs = n->num_inputs() / 2;
+  for (int idx = 0; idx < num_data_inputs; idx++) {
+    // Get the edge connecting input slot with index (idx).
+    const Edge* e = nullptr;
+    TF_CHECK_OK(n->input_edge(idx, &e));
+
+    // If e is control edge, then skip.
+    if (e->IsControlEdge()) {
+      continue;
+    }
+
+    // Check that the source node for edge 'e' is Mkl node. If it is not an Mkl
+    // node, then we don't need to do anything.
+    Node* e_src = e->src();
+    if (GetNodeAttr(e_src->def(), "T", &T).ok() &&
+        mkl_op_registry::IsMklOp(e_src->type_string(), T)) {
+      // Source node for edge 'e' is Mkl node.
+      // Destination node and destination input slot of e is node 'n' and 'idx'
+      // resp.
+      CHECK_EQ(e->dst(), n);
+      CHECK_EQ(e->dst_input(), idx);
+
+      // Let's get edge that carries Mkl metadata corresponding to Mkl data edge
+      // 'e'. For that, let's first get the input slot of 'n' where the meta
+      // edge will feed the value.
+      int e_meta_in_slot = GetTensorMetaDataIndex(e->dst_input(),
+                                                  n->num_inputs());
+      const Edge* e_meta = nullptr;
+      TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta));
+
+      // Let's check if we need to fix this meta edge.
+      if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) {
+        result = true;
+      }
+    }
+  }
+
+  return result;
+}
+
+///////////////////////////////////////////////////////////////////////////////
 //              Run function for the pass
 ///////////////////////////////////////////////////////////////////////////////
 
@@ -4307,6 +4422,25 @@ bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
 
   DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);
 
+  order.clear();
+  GetReversePostOrder(**g, &order);  // This will give us topological sort.
+  for (Node* n : order) {
+    // If node is not an op or it cannot run on CPU device, then skip.
+    if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
+      continue;
+    }
+    if (FixMklMetaDataEdges(g, n)) {
+      string node_name = n->name();
+      string op_name = n->type_string();
+
+      VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
+              << node_name << " with op " << op_name;
+      result = true;
+    }
+  }
+  DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
+            &**g);
+
   return result;
 }
 
index 029cdcf..7645b4a 100644 (file)
@@ -3519,6 +3519,37 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
 }
 
 /////////////////////////////////////////////////////////////////////
+//         Post-rewrite fixup pass test
+
+TEST_F(MklLayoutPassTest, PostRewriteFixUpPass) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'C' op: '_MklConv2D'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'data_format'      value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding'          value { s: 'SAME' } }"
+      " attr { key: 'dilations'        value { list: {i: 1, i:1, i:1, i:1} } }"
+      " input: ['A', 'B', 'M', 'N']}"
+      "node { name: 'D' op: 'Const' "
+      " attr { key: 'dtype' value { type: DT_UINT8 } }"
+      " attr { key: 'value' value { "
+      "    tensor { dtype: DT_UINT8 tensor_shape { dim { size: 1 } } "
+      "    int_val: 0 } } } }"
+      "node { name: 'E' op: '_MklAdd'"
+      " attr {key: 'T'                 value { type: DT_FLOAT } }"
+      " input: ['C', 'A', 'D', 'D']}");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(_MklConv2D);D(Const);E(_MklAdd);"
+            "M(_MklInput);N(_MklInput)|A->C;A->E:1;B->C:1;C->E;C:2->E:2;"
+            "D->E:3;M->C:2;N->C:3");
+}
+
+/////////////////////////////////////////////////////////////////////
 
 static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
   testing::StopTiming();