}
};
-class MishSubgraph2 : public Subgraph
+// softplus(x) = log(exp(x) + 1)
+class SoftplusSubgraph: public Subgraph
{
public:
- MishSubgraph2()
+ SoftplusSubgraph()
{
int input = addNodeToMatch("");
int exp = addNodeToMatch("Exp", input);
int addVal = addNodeToMatch("");
int add = addNodeToMatch("Add", addVal, exp);
- int log = addNodeToMatch("Log", add);
- int tanh = addNodeToMatch("Tanh", log);
- addNodeToMatch("Mul", input, tanh);
- setFusedNode("Mish", input);
+ addNodeToMatch("Log", add);
+ setFusedNode("Softplus", input);
}
};
-class MishSubgraph3 : public Subgraph
+class SoftplusSubgraph2: public Subgraph
{
public:
- MishSubgraph3()
+ SoftplusSubgraph2()
{
int input = addNodeToMatch("");
int exp = addNodeToMatch("Exp", input);
int addVal = addNodeToMatch("");
int add = addNodeToMatch("Add", exp, addVal);
- int log = addNodeToMatch("Log", add);
- int tanh = addNodeToMatch("Tanh", log);
- addNodeToMatch("Mul", input, tanh);
- setFusedNode("Mish", input);
+ addNodeToMatch("Log", add);
+ setFusedNode("Softplus", input);
}
};
subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>());
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
subgraphs.push_back(makePtr<ExpandSubgraph>());
+ subgraphs.push_back(makePtr<SoftplusSubgraph>());
+ subgraphs.push_back(makePtr<SoftplusSubgraph2>());
subgraphs.push_back(makePtr<MishSubgraph>());
- subgraphs.push_back(makePtr<MishSubgraph2>());
- subgraphs.push_back(makePtr<MishSubgraph3>());
subgraphs.push_back(makePtr<NormalizeSubgraph4>());
subgraphs.push_back(makePtr<NormalizeSubgraph5>());