generalize axis for concat fusion
authorYashasSamaga <yashas_2010@yahoo.com>
Sat, 4 Jul 2020 13:27:28 +0000 (18:57 +0530)
committerYashasSamaga <yashas_2010@yahoo.com>
Sat, 4 Jul 2020 13:27:28 +0000 (18:57 +0530)
modules/dnn/src/dnn.cpp

index 0c98c82..f650a71 100644 (file)
@@ -2525,8 +2525,7 @@ struct Net::Impl : public detail::NetImplBase
             // (and so we eliminate the concatenation layer, because the channels
             // are concatenated implicitly).
             Ptr<ConcatLayer> concatLayer = ld.layerInstance.dynamicCast<ConcatLayer>();
-            if( !concatLayer.empty() && concatLayer->axis == 1 && !concatLayer->padding &&
-                ld.outputBlobs.size() == 1 )
+            if( !concatLayer.empty() && !concatLayer->padding && ld.outputBlobs.size() == 1 )
             {
                 Mat& output = ld.outputBlobs[0];
                 UMat umat_output;
@@ -2563,7 +2562,8 @@ struct Net::Impl : public detail::NetImplBase
                 // the concatenation optimization is applied with batch_size > 1.
                 // so, for now, we only apply this optimization in the most popular
                 // case batch_size == 1.
-                if( output.dims == 4 && output.size[0] == 1 )
+                int axis = clamp(concatLayer->axis, output.dims);
+                if( output.total(0, axis) == 1 )
                 {
                     size_t i, ninputs = ld.inputBlobsId.size();
                     std::vector<LayerPin> realinputs(ninputs);
@@ -2602,14 +2602,14 @@ struct Net::Impl : public detail::NetImplBase
                             OpenCLBackendWrapper::update(ld.outputBlobsWrappers, umats);
                         }
 #endif
-                        Range chrange[] = { Range::all(), Range::all(), Range::all(), Range::all() };
+                        std::vector<Range> chrange(output.dims, Range::all());
                         int ofs = 0;
                         for( i = 0; i < ninputs; i++ )
                         {
                             LayerPin pin = realinputs[i];
                             LayerData* inp_i_data = &layers[pin.lid];
-                            int channels_i = ld.inputBlobs[i]->size[1];
-                            chrange[1] = Range(ofs, ofs + channels_i);
+                            int channels_i = ld.inputBlobs[i]->size[axis];
+                            chrange[axis] = Range(ofs, ofs + channels_i);
                             printf_(("\toutput %s(%d) to channels (%d, %d)\n", inp_i_data->layerInstance->name.c_str(),
                                    pin.oid, ofs, ofs + channels_i));
                             ofs += channels_i;