if (hasWeights && hasBias)
CV_CheckEQ(weights.total(), bias.total(), "Incompatible weights/bias blobs");
+ if (weights.total() == 1)
+ {
+ // The total() of bias should be same as weights.
+ if (hasBias)
+ inpBlob.convertTo(outBlob, CV_32F, weights.at<float>(0), bias.at<float>(0));
+ else
+ inpBlob.convertTo(outBlob, CV_32F, weights.at<float>(0));
+ return;
+ }
+
int endAxis;
for (endAxis = axis + 1; endAxis <= inpBlob.dims; ++endAxis)
{
void findBroadAxis(const MatShape& broadShape, const MatShape& outShape, size_t& axis, int& broadAxis)
{
+ // Currently, this function can only complete 1-dimensional expansion of broadShape.
+ // If there are two dimensions in broadShape that need to be expended, it will fail.
const size_t diff = outShape.size() - broadShape.size();
// find the first non-one element of the broadcasting shape
const MatShape& outShape = outShapes[node_proto.input(0)];
size_t axis = 0;
- int broadAxis = -1;
- findBroadAxis(broadShape, outShape, axis, broadAxis);
-
- // if there is a one dimension in the middle that should be broadcasted, broadcast it
- if (broadAxis != -1)
+ if (total(broadShape) != 1)
{
- opencv_onnx::NodeProto concat_node_proto = node_proto;
- const std::string& input1 = concat_node_proto.input(1);
+ // If broadShape is a scalar, we set axis as 0.
+ // Other-wise, we check broadcast is available.
+ int broadAxis = -1;
+ findBroadAxis(broadShape, outShape, axis, broadAxis);
+
+ // if there is a one dimension in the middle that should be broadcasted, broadcast it
+ if (broadAxis != -1)
+ {
+ opencv_onnx::NodeProto concat_node_proto = node_proto;
+ const std::string& input1 = concat_node_proto.input(1);
- expandMid(layerParams.name, concat_node_proto, input1, outShape[broadAxis]);
+ expandMid(layerParams.name, concat_node_proto, input1, outShape[broadAxis]);
- LayerParams concatLP;
- concatLP.name = layerParams.name + "/concat";
- concatLP.set("axis", broadAxis);
- concatLP.type = "Concat";
- concat_node_proto.set_output(0, concatLP.name);
+ LayerParams concatLP;
+ concatLP.name = layerParams.name + "/concat";
+ concatLP.set("axis", broadAxis);
+ concatLP.type = "Concat";
+ concat_node_proto.set_output(0, concatLP.name);
- addLayer(concatLP, concat_node_proto);
- node_proto.set_input(1, concatLP.name);
+ addLayer(concatLP, concat_node_proto);
+ node_proto.set_input(1, concatLP.name);
+ }
}
CV_Assert(axis != outShape.size());