[mlir][Linalg] Properly propagate transform result in ScalarizeOp
author Nicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 27 Dec 2022 14:14:58 +0000 (06:14 -0800)
committer Nicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 27 Dec 2022 14:16:55 +0000 (06:16 -0800)
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
mlir/test/Dialect/Linalg/transform-ops.mlir

index 347c530..5660891 100644 (file)
@@ -67,14 +67,14 @@ DiagnosedSilenceableFailure
 transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
                                    SmallVectorImpl<Operation *> &results,
                                    transform::TransformState &state) {
-#define DOWNSCALE(trans) \
-    { \
-      FailureOr<LinalgOp> res = tryApply<trans>(target); \
-      if (succeeded(res)) { \
-        results.push_back(*res); \
-        return DiagnosedSilenceableFailure::success(); \
-      } \
-    }
+#define DOWNSCALE(trans)                                                       \
+  {                                                                            \
+    FailureOr<LinalgOp> res = tryApply<trans>(target);                         \
+    if (succeeded(res)) {                                                      \
+      results.push_back(*res);                                                 \
+      return DiagnosedSilenceableFailure::success();                           \
+    }                                                                          \
+  }
 
 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
@@ -986,6 +986,10 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
   if (failed(maybeTilingResult))
     return emitDefaultDefiniteFailure(target);
 
+  if (target->getNumResults())
+    rewriter.replaceOp(target, maybeTilingResult->replacements);
+  else
+    rewriter.eraseOp(target);
   results.append(maybeTilingResult->tiledOps);
   return DiagnosedSilenceableFailure::success();
 }
index 89c8d32..fbf083c 100644 (file)
@@ -5,8 +5,16 @@ func.func @scalarize(%arg0: tensor<24x12xf32>,
                      %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
   // The op is first tiled by 10 in the first dimension, which creates a
   // dynamic size, and then scalarized, which brings the dimension to static 1.
-  // CHECK: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x12
+  // CHECK: %[[RES_LOOP_1:.*]] = scf.for {{.*}} -> (tensor<24x25xf32>)
+  // CHECK:   %[[RES_LOOP_2:.*]] = scf.for {{.*}} -> (tensor<?x25xf32>)
+  // CHECK:     %[[MM:.*]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x12
+  // CHECK:     %[[INS_2:.*]] = tensor.insert_slice %[[MM]] into %{{.*}} [1, 25] [1, 1] : tensor<1x25xf32> into tensor<?x25xf32>
+  // CHECK:     scf.yield %[[INS_2]] : tensor<?x25xf32>
+  // CHECK:   %[[INS_1:.*]] = tensor.insert_slice %[[RES_LOOP_2]] into %{{.*}}, 25] [1, 1] : tensor<?x25xf32> into tensor<24x25xf32>
+  // CHECK:   scf.yield %[[INS_1]] : tensor<24x25xf32>
   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+
+  // CHECK: return %[[RES_LOOP_1]] : tensor<24x25xf32>
   func.return %0 : tensor<24x25xf32>
 }
 
index 898cce7..64cf3fb 100644 (file)
@@ -8,7 +8,7 @@ transform.sequence failures(propagate) {
 
 //===----------------------------------------------------------------------===//
 // Check that operations are registered correctly through the extension
-// mechanism. Their syntax is generated and requries no additional testing since
+// mechanism. Their syntax is generated and requires no additional testing since
 // we test the generator.
 //===----------------------------------------------------------------------===//