Shape start(num_dims);
Shape size(num_dims);
+ std::vector<int32_t> squeeze_dims;
for (int axis = 0; axis < num_dims; axis++) {
if (begin_mask & (1 << axis))
start.dim(axis) = 0;
else
start.dim(axis) = begin.at(Index{axis});
+
if (end_mask & (1 << axis))
size.dim(axis) = inputs[0]->getOutputShape(0).dim(axis) - start.dim(axis);
else
size.dim(axis) = end.at(Index{axis}) - start.dim(axis);
+
+ if (shrink_axis_mask & (1 << axis))
+ squeeze_dims.push_back(axis);
}
- std::vector<int32_t> squeeze_dims{shrink_axis_mask - 1};
-
auto slice_outputs = createOp<ops::SliceOp>(ActivationFunctionType_NONE,
- inputs[0]->getOutput(0), start, size);
+ inputs[0]->getOutput(0), start, size);
return createOp<ops::SqueezeOp>(ActivationFunctionType_NONE,
slice_outputs[0]->getOutput(0), squeeze_dims);
}