bool fix_shape(moco::tf::TFSqueeze *node)
{
- // TODO implement
- throw std::runtime_error("NYI fix_shape TFSqueeze");
+ auto shapedata = node->annot<ShapeInferenceData>();
+ if (shapedata != nullptr)
+ {
+ // shape inference is already done for TFSqueeze
+ return false;
+ }
+
+ auto input = node->input();
+ auto input_shape = input->annot<ShapeInferenceData>();
+ if (input_shape == nullptr)
+ {
+ // Input shape is required for TFSqueeze shape inference
+ return false;
+ }
+
+ // TODO Not sure Squeeze only get input as Tensor
+ // Note that tensor_shape() has assertion in it
+ auto input_tensor_shape = input_shape->tensor_shape();
+
+ auto squeeze_dims_vec = node->squeeze_dims();
+ const std::set<int64_t> squeeze_dims(squeeze_dims_vec.cbegin(), squeeze_dims_vec.cend());
+
+ loco::TensorShape node_shape;
+ uint32_t node_rank = 0;
+
+ if (squeeze_dims.empty())
+ {
+ // Remove all dimensions whose value is 1
+ for (uint32_t axis = 0; axis < input_tensor_shape.rank(); ++axis)
+ {
+ assert(input_tensor_shape.dim(axis).known());
+ auto dim = input_tensor_shape.dim(axis).value();
+ if (dim != 1)
+ {
+ assert(dim > 1);
+ node_shape.rank(++node_rank);
+ node_shape.dim(node_rank - 1) = dim;
+ }
+ }
+ }
+ else
+ {
+ uint32_t input_rank = input_tensor_shape.rank();
+
+ auto is_valid_squeeze_dims = [&squeeze_dims, &input_rank]() {
+ if (squeeze_dims.size() >= input_rank)
+ return false;
+ for (auto squeeze_dim : squeeze_dims)
+ {
+ // Negative squeeze dimensions should be resolve before
+ if (squeeze_dim < 0)
+ return false;
+ if (squeeze_dim >= (int64_t)input_rank)
+ return false;
+ }
+ return true;
+ };
+
+ assert(is_valid_squeeze_dims());
+
+ // Remove squeeze dimensions only
+ for (uint32_t axis = 0; axis < input_rank; ++axis)
+ {
+ assert(input_tensor_shape.dim(axis).known());
+ auto dim = input_tensor_shape.dim(axis).value();
+ if (squeeze_dims.find((int64_t)axis) == squeeze_dims.cend())
+ {
+ // Not squeeze dim
+ node_shape.rank(++node_rank);
+ node_shape.dim(node_rank - 1) = dim;
+ }
+ else
+ {
+ // Is squeeze dim
+ assert(dim == 1);
+ // DO NOTHING
+ }
+ }
+ }
+
+ assert(node_shape.rank() > 0);
+
+ auto shape_annot = stdex::make_unique<ShapeInferenceData>();
+ shape_annot->tensor_shape(node_shape);
+ node->annot(std::move(shape_annot));
+
+ LOGGER(l);
+ INFO(l) << "Fix TFSqueeze shape = " << node_shape;
+
+ return true;
}
} // namespace