MatShape inpShape = outShapes[node_proto.input(0)];
DictValue axes = layerParams.get("axes");
bool keepdims = layerParams.get<int>("keepdims");
- MatShape targetShape = inpShape;
+ MatShape targetShape;
+ std::vector<bool> shouldDelete(inpShape.size(), false);
for (int i = 0; i < axes.size(); i++) {
int axis = clamp(axes.get<int>(i), inpShape.size());
- if (keepdims) {
- targetShape[axis] = 1;
- } else {
- targetShape.erase(targetShape.begin() + axis);
- }
+ shouldDelete[axis] = true;
+ }
+ for (int axis = 0; axis < inpShape.size(); ++axis){
+ if (!shouldDelete[axis])
+ targetShape.push_back(inpShape[axis]);
+ else if (keepdims)
+ targetShape.push_back(1);
}
if (inpShape.size() == 3 && axes.size() <= 2)