{
setParamsFrom(params);
axis = params.get<int>("axis", 1);
+ num_split = params.get<int>("num_split", 0);
if (params.has("slice_point"))
{
CV_Assert(!params.has("begin") && !params.has("size") && !params.has("end"));
else // Divide input blob on equal parts by axis.
{
CV_Assert(0 <= axis && axis < inpShape.size());
- CV_Assert(requiredOutputs > 0 && inpShape[axis] % requiredOutputs == 0);
- inpShape[axis] /= requiredOutputs;
- outputs.resize(requiredOutputs, inpShape);
+ int splits = num_split ? num_split : requiredOutputs;
+ CV_Assert(splits > 0 && inpShape[axis] % splits == 0);
+ inpShape[axis] /= splits;
+ outputs.resize(splits, inpShape);
}
return false;
}
axis = toNCHW(axis);
layerParams.set("axis", axis);
+ if (hasLayerAttr(layer, "num_split"))
+ layerParams.set("num_split", getLayerAttr(layer, "num_split").i());
+
int id = dstNet.addLayer(name, "Slice", layerParams);
layer_id[name] = id;
runTensorFlowNet("l2_normalize_3d");
}
+TEST_P(Test_TensorFlow_layers, Split)
+{
+ runTensorFlowNet("split");
+}
+
class Test_TensorFlow_nets : public DNNTestLayer {};
TEST_P(Test_TensorFlow_nets, MobileNet_SSD)