return isOnIdeepDevice(conv) ? (conv.getGroup() <= 1) : false;
}
+void removeStopGradientForInference(repr::NNModule *nn) {
+ auto isStopGradientNode = [](const repr::NNGraph::NodeRef& node) {
+ if (!repr::nn::is<repr::NeuralNetOperator>(node)) {
+ return false;
+ }
+ auto maybeStopGrad = repr::nn::get<repr::NeuralNetOperator>(node);
+ auto maybeStopGradDef = getOpDef(*maybeStopGrad);
+ return maybeStopGradDef.type() == "StopGradient";
+ };
+
+ auto allNodes = nn->dataFlow.getMutableNodes();
+ for (int i = 0; i < allNodes.size(); ++i) {
+ auto node = allNodes[i];
+ if (!isStopGradientNode(node)) {
+ continue;
+ }
+
+ auto stopGradInput = repr::nn::getInputs(node).front();
+ auto stopGradOutput = repr::nn::getOutputs(node).front();
+ auto inputName = repr::nn::get<repr::Tensor>(stopGradInput)->getName();
+ auto outputName = repr::nn::get<repr::Tensor>(stopGradOutput)->getName();
+ if (inputName == outputName) {
+ nn->dataFlow.replaceNode(stopGradOutput, stopGradInput);
+ nn->dataFlow.deleteNode(node);
+ }
+ }
+}
+
void resetConvForFusion(repr::NNGraph::NodeRef convNode, int fusion_type) {
// Fusion types:
// FUSION_CONV_RELU = 1
return;
}
+ removeStopGradientForInference(nn);
+
fuseConvBNAndAffChForIdeep(nn, ws);
fuseConvSumForIdeep(nn, ws);