[mlir][tensor] Add producer fusion for tensor.unpack op.
authorHanhan Wang <hanchung@google.com>
Fri, 6 Jan 2023 18:49:08 +0000 (10:49 -0800)
committerHanhan Wang <hanchung@google.com>
Fri, 6 Jan 2023 22:13:11 +0000 (14:13 -0800)
Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D141151

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
mlir/test/Dialect/Linalg/transform-op-fuse.mlir

index 46324672ad2f2ff9e45f7129a1064347352b5d16..7eced90aac548926ff208e484ced6da1d525f685 100644 (file)
@@ -425,6 +425,15 @@ struct UnPackOpTiling
     resultSizes = llvm::to_vector(sizes);
     return success();
   }
+
+  FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
+                                           unsigned resultNumber,
+                                           ArrayRef<OpFoldResult> offsets,
+                                           ArrayRef<OpFoldResult> sizes) const {
+    return getTiledImplementation(op, b, offsets, sizes)
+        .back()
+        ->getResult(resultNumber);
+  }
 };
 
 } // namespace
index 7cbda2d1c85d708451ec75ff572c3df379cab10e..580ad597ef30d9027d58887089e85bacbba611bb 100644 (file)
@@ -91,3 +91,26 @@ transform.sequence failures(propagate) {
   %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]}
   %2, %loops_2 = transform.structured.tile %1 [0, 4]
 }
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_elemwise
+// CHECK:         %[[RES:.*]] = scf.for
+// CHECK:           scf.for
+// CHECK:             tensor.unpack
+// CHECK:             linalg.elemwise_unary
+// CHECK:         return %[[RES]]
+func.func @unpack_elemwise(%arg0: tensor<16x48x8x8xf32>, %arg1: tensor<128x384xf32>) -> tensor<128x384xf32> {
+  %0 = tensor.empty() : tensor<128x384xf32>
+  %1 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %0
+      : tensor<16x48x8x8xf32> -> tensor<128x384xf32>
+  %2 = linalg.elemwise_unary ins(%1: tensor<128x384xf32>)
+                             outs(%arg1: tensor<128x384xf32>) -> tensor<128x384xf32>
+  return %2 : tensor<128x384xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1
+  %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [16, 32], tile_interchange = [0, 1]}
+}