From 35480a7c444470e753226b9a2d4e766650f90590 Mon Sep 17 00:00:00 2001 From: wuhuikx Date: Fri, 11 Jan 2019 23:52:28 -0800 Subject: [PATCH] Remove StopGradient op when it is inplace in inference (#12152) Summary: For Inference, if the StopGradient op is inpalce, we just remove it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/12152 Differential Revision: D13633946 Pulled By: yinghai fbshipit-source-id: 57762bcc37b38a1d39cb4af316ca50bfe961b105 --- caffe2/opt/optimize_ideep.cc | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/caffe2/opt/optimize_ideep.cc b/caffe2/opt/optimize_ideep.cc index e4a9feb..c657cd7 100644 --- a/caffe2/opt/optimize_ideep.cc +++ b/caffe2/opt/optimize_ideep.cc @@ -72,6 +72,34 @@ bool shouldFuseConv(const repr::Conv& conv) { return isOnIdeepDevice(conv) ? (conv.getGroup() <= 1) : false; } +void removeStopGradientForInference(repr::NNModule *nn) { + auto isStopGradientNode = [](const repr::NNGraph::NodeRef& node) { + if (!repr::nn::is(node)) { + return false; + } + auto maybeStopGrad = repr::nn::get(node); + auto maybeStopGradDef = getOpDef(*maybeStopGrad); + return maybeStopGradDef.type() == "StopGradient"; + }; + + auto allNodes = nn->dataFlow.getMutableNodes(); + for (int i = 0; i < allNodes.size(); ++i) { + auto node = allNodes[i]; + if (!isStopGradientNode(node)) { + continue; + } + + auto stopGradInput = repr::nn::getInputs(node).front(); + auto stopGradOutput = repr::nn::getOutputs(node).front(); + auto inputName = repr::nn::get(stopGradInput)->getName(); + auto outputName = repr::nn::get(stopGradOutput)->getName(); + if (inputName == outputName) { + nn->dataFlow.replaceNode(stopGradOutput, stopGradInput); + nn->dataFlow.deleteNode(node); + } + } +} + void resetConvForFusion(repr::NNGraph::NodeRef convNode, int fusion_type) { // Fusion types: // FUSION_CONV_RELU = 1 @@ -440,6 +468,8 @@ void OptimizeForIdeep( return; } + removeStopGradientForInference(nn); + fuseConvBNAndAffChForIdeep(nn, ws); fuseConvSumForIdeep(nn, ws); -- 2.7.4