Remove StopGradient op when it is inplace in inference (#12152)
authorwuhuikx <hui.h.wu@intel.com>
Sat, 12 Jan 2019 07:52:28 +0000 (23:52 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 12 Jan 2019 07:55:01 +0000 (23:55 -0800)
Summary:
For Inference, if the StopGradient op is inpalce, we just remove it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12152

Differential Revision: D13633946

Pulled By: yinghai

fbshipit-source-id: 57762bcc37b38a1d39cb4af316ca50bfe961b105

caffe2/opt/optimize_ideep.cc

index e4a9feb..c657cd7 100644 (file)
@@ -72,6 +72,34 @@ bool shouldFuseConv(const repr::Conv& conv) {
   return isOnIdeepDevice(conv) ? (conv.getGroup() <= 1) : false;
 }
 
+void removeStopGradientForInference(repr::NNModule *nn) {
+  auto isStopGradientNode = [](const repr::NNGraph::NodeRef& node) {
+    if (!repr::nn::is<repr::NeuralNetOperator>(node)) {
+      return false;
+    }
+    auto maybeStopGrad = repr::nn::get<repr::NeuralNetOperator>(node);
+    auto maybeStopGradDef = getOpDef(*maybeStopGrad);
+    return maybeStopGradDef.type() == "StopGradient";
+  };
+
+  auto allNodes = nn->dataFlow.getMutableNodes();
+  for (int i = 0; i < allNodes.size(); ++i) {
+    auto node = allNodes[i];
+    if (!isStopGradientNode(node)) {
+      continue;
+    }
+
+    auto stopGradInput = repr::nn::getInputs(node).front();
+    auto stopGradOutput = repr::nn::getOutputs(node).front();
+    auto inputName = repr::nn::get<repr::Tensor>(stopGradInput)->getName();
+    auto outputName = repr::nn::get<repr::Tensor>(stopGradOutput)->getName();
+    if (inputName == outputName) {
+      nn->dataFlow.replaceNode(stopGradOutput, stopGradInput);
+      nn->dataFlow.deleteNode(node);
+    }
+  }
+}
+
 void resetConvForFusion(repr::NNGraph::NodeRef convNode, int fusion_type) {
   // Fusion types:
   // FUSION_CONV_RELU = 1
@@ -440,6 +468,8 @@ void OptimizeForIdeep(
     return;
   }
 
+  removeStopGradientForInference(nn);
+
   fuseConvBNAndAffChForIdeep(nn, ws);
 
   fuseConvSumForIdeep(nn, ws);