reassociation.emplace_back(ReassociationIndices{});
std::swap(reassociation.back(), curr);
}
- if (!curr.empty())
+  // When the reassociations are not empty, then fold the remaining
+  // unit-dimensions into the last dimension. If the reassociation so far is
+  // empty, then leave it empty. This will fold everything to a rank-0 tensor.
+ if (!curr.empty() && !reassociation.empty())
reassociation.back().append(curr.begin(), curr.end());
return reassociation;
}
// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] {{\[}}[0, 1]]
// CHECK: return %[[RESULT_RESHAPE]]
+
+// -----
+
+func @subtensor_unit_dims(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> {
+ %0 = subtensor %arg0[0, 2] [1, 1] [1, 1] : tensor<1x3xf32> to tensor<1x1xf32>
+ return %0 : tensor<1x1xf32>
+}
+// CHECK-LABEL: func @subtensor_unit_dims
+// CHECK: %[[SUBTENSOR:.+]] = subtensor
+// CHECK-SAME: tensor<1x3xf32> to tensor<f32>
+// CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[SUBTENSOR]] []
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @subtensor_insert_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x3xf32> {
+ %0 = subtensor_insert %arg1 into %arg0[0, 2] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x3xf32>
+ return %0 : tensor<1x3xf32>
+}
+// CHECK-LABEL: func @subtensor_insert_unit_dims
+// CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} []
+// CHECK: %[[RESULT:.+]] = subtensor_insert %[[RESHAPE]]
+// CHECK-SAME: tensor<f32> into tensor<1x3xf32>
+// CHECK: return %[[RESULT]]