[mlir][linalg] Fix canonicalizer crash for `linalg.generic` with mixed semantics
authorIvan Butygin <ivan.butygin@gmail.com>
Tue, 11 Oct 2022 23:14:35 +0000 (01:14 +0200)
committerIvan Butygin <ivan.butygin@gmail.com>
Mon, 17 Oct 2022 13:20:42 +0000 (15:20 +0200)
`EraseIdentityGenericOp` for `!hasBufferSemantics()` assumed fully tensor semantics and tried to access non-existent return values.

Differential Revision: https://reviews.llvm.org/D135725

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir

index 586d198..2fcd21c 100644 (file)
@@ -1176,6 +1176,10 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
       return failure();
     }
 
+    // Mixed semantics is not supported yet.
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+
     // Get the argument number of the returned values. That is the operand
     // number to use for replacing uses of this operation.
     SmallVector<Value> returnedArgs;
index 2062c65..8f8f600 100644 (file)
@@ -821,3 +821,28 @@ func.func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor<?x?x?xf32>) ->
 //  CHECK-SAME:       outs(%[[INIT2]], %[[INIT1]] :
 //       CHECK:   %[[RETURN_CAST:.+]] = tensor.cast %[[GENERIC]]#0 : tensor<3x2x4xf32> to tensor<?x?x?xf32>
 //       CHECK:   return %[[RETURN_CAST]], %[[GENERIC]]#1
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
+  linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel"]
+  } ins(%arg0 : tensor<?xf32>)
+    outs(%arg1 : memref<?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32):
+    linalg.yield %arg2 : f32
+  }
+  return
+}
+
+// There was a crash in EraseIdentityGenericOp for generic with mixed semantics.
+// For now, check generic remained unchanged.
+// CHECK-LABEL: func @identity_mixed
+//  CHECK-SAME:     (%[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: memref<?xf32>)
+//       CHECK:     linalg.generic {
+//  CHECK-SAME:    indexing_maps = [#map, #map],
+//  CHECK-SAME:    iterator_types = ["parallel"]
+//  CHECK-SAME:  } ins(%[[ARG1]] : tensor<?xf32>)
+//  CHECK-SAME:    outs(%[[ARG2]] : memref<?xf32>) {