uint32_t input_rank = input_tensor_shape.rank();
// Sanity check for 'squeeze_dims'
- // TODO make lambda to hide these assertions
- {
- assert(squeeze_dims.size() < input_rank);
+ auto is_valid_squeeze_dims = [&squeeze_dims, &input_rank]() {
+ if (!(squeeze_dims.size() < input_rank))
+ return false;
for (auto squeeze_dim : squeeze_dims)
{
- assert(squeeze_dim >= -(int64_t)input_rank);
- assert(squeeze_dim < (int64_t)input_rank);
+ if (!(squeeze_dim >= -(int64_t)input_rank))
+ return false;
+ if (!(squeeze_dim < (int64_t)input_rank))
+ return false;
}
+ return true;
};
+ if (!is_valid_squeeze_dims())
+ {
+ throw std::runtime_error("Fix shape for TFSqueeze: invalid squeeze dimension");
+ }
+
// Resolve negative squeeze dimension
for (auto squeeze_dim : squeeze_dims)
{