Fix LogSoftmax for ONNX
author dianlujitao <dianlujitao@gmail.com>
Mon, 20 May 2019 07:46:09 +0000 (15:46 +0800)
committer dianlujitao <dianlujitao@gmail.com>
Mon, 20 May 2019 14:55:31 +0000 (22:55 +0800)
Fix wrong indentation as well while at it

modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

index 91110cb..b7d289d 100644 (file)
@@ -786,37 +786,42 @@ void ONNXImporter::populateNet(Net dstNet)
             }
             replaceLayerParam(layerParams, "mode", "interpolation");
         }
+        else if (layer_type == "LogSoftmax")
+        {
+            layerParams.type = "Softmax";
+            layerParams.set("log_softmax", true);
+        }
         else
         {
             for (int j = 0; j < node_proto.input_size(); j++) {
                 if (layer_id.find(node_proto.input(j)) == layer_id.end())
                     layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
             }
-         }
-
-         int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
-         layer_id.insert(std::make_pair(layerParams.name, LayerInfo(id, 0)));
-
-
-         std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
-         for (int j = 0; j < node_proto.input_size(); j++) {
-             layerId = layer_id.find(node_proto.input(j));
-             if (layerId != layer_id.end()) {
-                 dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, j);
-                 // Collect input shapes.
-                 shapeIt = outShapes.find(node_proto.input(j));
-                 CV_Assert(shapeIt != outShapes.end());
-                 layerInpShapes.push_back(shapeIt->second);
-             }
-         }
-
-         // Compute shape of output blob for this layer.
-         Ptr<Layer> layer = dstNet.getLayer(id);
-         layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
-         CV_Assert(!layerOutShapes.empty());
-         outShapes[layerParams.name] = layerOutShapes[0];
-     }
- }
+        }
+
+        int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
+        layer_id.insert(std::make_pair(layerParams.name, LayerInfo(id, 0)));
+
+
+        std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
+        for (int j = 0; j < node_proto.input_size(); j++) {
+            layerId = layer_id.find(node_proto.input(j));
+            if (layerId != layer_id.end()) {
+                dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, j);
+                // Collect input shapes.
+                shapeIt = outShapes.find(node_proto.input(j));
+                CV_Assert(shapeIt != outShapes.end());
+                layerInpShapes.push_back(shapeIt->second);
+            }
+        }
+
+        // Compute shape of output blob for this layer.
+        Ptr<Layer> layer = dstNet.getLayer(id);
+        layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
+        CV_Assert(!layerOutShapes.empty());
+        outShapes[layerParams.name] = layerOutShapes[0];
+    }
+}
 
 Net readNetFromONNX(const String& onnxFile)
 {
index bf9f25d..eb30628 100644 (file)
@@ -245,6 +245,12 @@ TEST_P(Test_ONNX_layers, Reshape)
     testONNXModels("unsqueeze");
 }
 
+TEST_P(Test_ONNX_layers, Softmax)
+{
+    testONNXModels("softmax");
+    testONNXModels("log_softmax");
+}
+
 INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
 
 class Test_ONNX_nets : public Test_ONNX_layers {};