opencv_onnx::GraphProto& net;
};
-class SoftMaxSubgraph : public Subgraph
+class SoftMaxSubgraphBase : public Subgraph
{
public:
- SoftMaxSubgraph() : axis(1)
- {
- int input = addNodeToMatch("");
- int inpExp = addNodeToMatch("Exp", input);
- int sum = addNodeToMatch("ReduceSum", inpExp);
- addNodeToMatch("Div", inpExp, sum);
- setFusedNode("Softmax", input);
- }
+ SoftMaxSubgraphBase() : axis(1), id(-1) {}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
{
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
{
- Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[1]);
+ CV_Assert(id >= 0 && id < matchedNodesIds.size());
+ Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[id]);
opencv_onnx::NodeProto* node = sum.dynamicCast<ONNXNodeWrapper>()->node;
for (int i = 0; i < node->attribute_size(); i++)
attr->set_i(axis);
}
-private:
+protected:
int axis;
+ int id;
+};
+
+class SoftMaxSubgraph : public SoftMaxSubgraphBase
+{
+public:
+ SoftMaxSubgraph()
+ {
+ int input = addNodeToMatch("");
+ int inpExp = addNodeToMatch("Exp", input);
+
+ int sum = addNodeToMatch("ReduceSum", inpExp);
+ id = 1;
+
+ addNodeToMatch("Div", inpExp, sum);
+ setFusedNode("Softmax", input);
+ }
+};
+
+class SoftMaxSubgraph2 : public SoftMaxSubgraphBase {
+public:
+ SoftMaxSubgraph2() {
+ int input = addNodeToMatch("");
+
+ int reducemax = addNodeToMatch("ReduceMax", input);
+ id = 0;
+
+ int sub = addNodeToMatch("Sub", input, reducemax);
+ int exp = addNodeToMatch("Exp", sub);
+ int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
+ addNodeToMatch("Div", exp, reducesum);
+ setFusedNode("Softmax", input);
+ }
+};
+
+class LogSoftMaxSubgraph : public SoftMaxSubgraphBase
+{
+public:
+ LogSoftMaxSubgraph()
+ {
+ int input = addNodeToMatch("");
+
+ int reducemax = addNodeToMatch("ReduceMax", input);
+ id = 0;
+
+ int sub_1 = addNodeToMatch("Sub", input, reducemax);
+ int exp = addNodeToMatch("Exp", sub_1);
+ int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
+ int log = addNodeToMatch("Log", reducesum);
+ addNodeToMatch("Sub", sub_1, log);
+ setFusedNode("LogSoftmax", input);
+ }
};
class NormalizeSubgraphBase : public Subgraph
subgraphs.push_back(makePtr<ResizeSubgraph1>());
subgraphs.push_back(makePtr<ResizeSubgraph2>());
subgraphs.push_back(makePtr<SoftMaxSubgraph>());
+ subgraphs.push_back(makePtr<SoftMaxSubgraph2>());
+ subgraphs.push_back(makePtr<LogSoftMaxSubgraph>());
subgraphs.push_back(makePtr<NormalizeSubgraph1>());
subgraphs.push_back(makePtr<NormalizeSubgraph2>());
subgraphs.push_back(makePtr<NormalizeSubgraph2_2>());