[mlir][Linalg][Transform] Add support to let `transform.structured.pack_greedily...
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 27 Mar 2023 15:00:43 +0000 (08:00 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 28 Mar 2023 06:37:13 +0000 (23:37 -0700)
This increase the flexibility of the transformation to allow mixed packing / padding specifications.

Differential Revision: https://reviews.llvm.org/D146969

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
mlir/test/Dialect/Linalg/transform-pack-greedily.mlir

index e107911..3b80712 100644 (file)
@@ -588,14 +588,27 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
     Target a Linalg op and rewrite it into packed LinalgOp form by trying to
     infer whether a known suboperation is embedded
 
-    Different packing strategies are applied in order, when one applies
+    Different packing strategies are applied in order, when one applies 
     successfully, the transform returns:
       1. Matmul packing: Try to infer a matmul operation embedded in the target op.
          Specifically, this looks for 2 parallel dimensions that participate in
          an outer-product and 1 reduction dimension.
          These dimensions are referred as (m, n, k) to match canonical matmul
          terminology.
-         The packed sizes for (m, n, k) are specified by `matmul_packed_sizes`.
+         
+         The packed sizes for (m, n, k) are specified by `matmul_packed_sizes`
+         and the optional `matmul_padded_sizes_next_multiple_of`.
+         When an entry `matmul_packed_sizes[i]` is non-0, the corresponding 
+         dimension is packed by `matmul_packed_sizes[i]`.
+         Otherwise, the dimension is merely padded to the next multiple of
+         `matmul_padded_sizes_next_multiple_of[i]`.
+
+         `matmul_padded_sizes_next_multiple_of` is optional and is expected to
+         either be empty or of size `3`, matching the size of `matmul_packed_sizes`.
+         For each individual element of `matmul_packed_sizes` and 
+         `matmul_padded_sizes_next_multiple_of`, only one of them is allowed to
+         be non-zero.
+         
          The ordering of the packed dimensions (mm, nn, kk) is specified by the
          `matmul_inner_dims_order` attribute.
 
@@ -605,10 +618,15 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
       3. An interchange transform is applied to isolate the dimensions to pack as
          the most minor indexing dimensions of the linalg.generic. The most minor
          dimensions are themselves ordered according to `inner_dims_order`.
-      4. Packing is performed by `packed_sizes` and following `inner_dims_order`.
+      4. An elementwise traversal of `matmul_packed_sizes` and
+         `matmul_padded_sizes_next_multiple_of` is performed and for each 
+         dimension `d`, either pack to `matmul_packed_sizes[d]` or pad to the
+         `matmul_padded_sizes_next_multiple_of[d]`.
+      5. Packing/padding is performed by the amounts determined in step 4. and
+         following `inner_dims_order`.
 
     By normalizing the most minor dimensions to `inner_dims_order`, the transform
-    guarantees that packing immediates generates inner dimensions in a desirable
+    guarantees that packing immediately generates inner dimensions in a desirable
     layout.
 
     Outer dimension layout permutations are not controlled by this transform op
@@ -625,15 +643,23 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
   // TODO: Transform_ConcreteOpType<linalg::LinalgOp> needs interface.
   let arguments = (ins TransformHandleTypeInterface:$target,
                    Variadic<PDL_Operation>:$matmul_packed_sizes,
-                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">
-                     :$static_matmul_packed_sizes,
-                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">
-                     :$matmul_inner_dims_order);
+                   ConfinedAttr<DefaultValuedAttr<DenseI64ArrayAttr, "{}">,
+                                 [DenseArrayCount<3>]>:$static_matmul_packed_sizes,
+                   ConfinedAttr<DefaultValuedAttr<DenseI64ArrayAttr, "{}">,
+                                 [Attr<
+                                    Or<[DenseArrayCount<0>.predicate, 
+                                        DenseArrayCount<3>.predicate]>,
+                                        "with 0 or 3 elements"
+                                      >]>
+                                 :$matmul_padded_sizes_next_multiple_of,
+                   ConfinedAttr<DefaultValuedAttr<DenseI64ArrayAttr, "{}">,
+                                 [DenseArrayCount<3>]>:$matmul_inner_dims_order);
   let results = (outs Transform_ConcreteOpType<"linalg.generic">:$packed_op);
 
   let builders = [
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedMatmulPackedSizes,
+                   "ArrayRef<int64_t>":$matmulPaddededSizesNextMultipleOf,
                    CArg<"ArrayRef<int64_t>", "{}">:$matmulDimsInnerDimsOrder)>
   ];
 
@@ -641,7 +667,9 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
     $target
     oilist(
       `matmul_packed_sizes` `=` custom<DynamicIndexList>($matmul_packed_sizes,
-                                                       $static_matmul_packed_sizes)
+                                                         $static_matmul_packed_sizes)
+      (`matmul_padded_sizes_next_multiple_of` `=` 
+        $matmul_padded_sizes_next_multiple_of^)?
       `matmul_inner_dims_order` `=` $matmul_inner_dims_order
     )
     attr-dict
index 6ee0f13..44ef944 100644 (file)
@@ -26,6 +26,7 @@
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/Dialect/Transform/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -1298,11 +1299,18 @@ LogicalResult transform::PackGreedilyOp::verify() {
                          << " is not a valid permutation";
   }
   // TODO: relax to allow empty once we have another strategy than just matmul.
-  if (getMatmulInnerDimsOrder().size() != 3 ||
-      getMixedMatmulPackedSizes().size() != 3) {
-    return emitOpError() << " needs 3 entries for matmul_packed_sizes and "
-                         << getMatmulInnerDimsOrderAttrName()
-                         << " order for the matmul strategy";
+  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
+    for (auto [s, nmo] :
+         llvm::zip_equal(getMixedMatmulPackedSizes(),
+                         getMatmulPaddedSizesNextMultipleOf())) {
+      std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
+      if (nmo != 0 &&
+          (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
+        return emitOpError() << "at most one of the packed_size and the "
+                                "padded_sizes_next_multiple_of can be nonzero "
+                                "for the matmul strategy";
+      }
+    }
   }
   return success();
 }
@@ -1318,8 +1326,12 @@ LogicalResult transform::PackGreedilyOp::verify() {
 static FailureOr<PackResult>
 packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                    ArrayRef<OpFoldResult> mnkPackedSizes,
+                   ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                    ArrayRef<int64_t> mnkOrder) {
   assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
+  assert(mnkPaddedSizesNextMultipleOf.empty() ||
+         mnkPaddedSizesNextMultipleOf.size() == 3 &&
+             "num of packing sizes next multiple should be empty or of size 3");
   assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
   assert(isPermutationVector(mnkOrder) && "expected a permutation");
 
@@ -1334,9 +1346,15 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
   SmallVector<int64_t> mmnnkkPos(numPackedDims);
   for (int64_t i = 0, e = numPackedDims; i < e; ++i)
     mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
-  SmallVector<OpFoldResult> packedSizes(mnkPackedSizes.size());
+  SmallVector<OpFoldResult> packedSizes(numPackedDims);
   for (int64_t i = 0, e = numPackedDims; i < e; ++i)
     packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
+  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
+  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
+    paddedSizesNextMultipleOf[mnkOrder[i]] =
+        mnkPaddedSizesNextMultipleOf.empty() ? 0
+                                             : mnkPaddedSizesNextMultipleOf[i];
+  }
 
   // 1. Infer dims that are important for matmul.
   FailureOr<EmbeddedMatmulDimsCandidates> res = inferMatmulDims(linalgOp);
@@ -1391,10 +1409,37 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
   // desired outerPerm for each operand.
   // This is left for future work.
 
-  // Add leading zeros to match numLoops.
+  // TODO: this creates too much IR, go use reifyResultShapes.
+  SmallVector<Range, 4> loopRanges =
+      cast<LinalgOp>(genericOp.getOperation())
+          .createLoopRanges(rewriter, genericOp.getLoc());
+
+  // Add leading zeros to match numLoops, we only pack the last 3 dimensions
+  // post interchange.
+  LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
+                                   DBGS() << "paddedSizesNextMultipleOf: ");
+             DBGSNL(););
+  LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
+                                   [](Range r) { llvm::dbgs() << r.size; });
+             DBGSNL(););
   SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
                                                 rewriter.getIndexAttr(0));
-  llvm::append_range(adjustedPackedSizes, packedSizes);
+  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
+    if (paddedSizesNextMultipleOf[i] == 0) {
+      adjustedPackedSizes.push_back(packedSizes[i]);
+      continue;
+    }
+    AffineExpr d0, s0;
+    bindDims(rewriter.getContext(), d0);
+    bindSymbols(rewriter.getContext(), s0);
+    adjustedPackedSizes.push_back(makeComposedFoldedAffineApply(
+        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
+        {loopRanges[adjustedPackedSizes.size()].size,
+         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
+  }
+  LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
+                                   DBGS() << "adjustedPackedSizes: ");
+             DBGSNL(););
 
   // TODO: If we wanted to give the genericOp a name after packing, after
   // calling `pack` would be a good time.
@@ -1424,6 +1469,8 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults,
         /*rewriter=*/rewriter,
         /*linalgOp=*/linalgOp,
         /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
+        /*mnkPaddedSizesNextMultipleOf=*/
+        getMatmulPaddedSizesNextMultipleOf(),
         /*mnkOrder=*/getMatmulInnerDimsOrder());
     if (succeeded(packResult)) {
       results.push_back(packResult->packedLinalgOp);
index 8645fa3..2adfd9a 100644 (file)
@@ -46,3 +46,28 @@ transform.sequence failures(propagate) {
   "transform.structured.multitile_sizes"(%arg0) { target_size = 3, divisor = 2, dimension = 0 }
       : (!pdl.operation) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i32>)
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{not a valid permutation}}
+  transform.structured.pack_greedily %arg0
+      matmul_packed_sizes = [8, 0, 32] 
+      matmul_inner_dims_order = [1, 1, 0]
+    : (!pdl.operation) -> !transform.op<"linalg.generic">
+
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{at most one of the packed_size and the padded_sizes_next_multiple_of can be nonzero}}
+  transform.structured.pack_greedily %arg0
+      matmul_packed_sizes = [1, 1, 1] 
+      matmul_padded_sizes_next_multiple_of = [1, 1, 1] 
+      matmul_inner_dims_order = [0, 1, 2]
+    : (!pdl.operation) -> !transform.op<"linalg.generic">
+
+}
index 544f439..fdb1699 100644 (file)
@@ -226,3 +226,52 @@ transform.sequence failures(propagate) {
       matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
     : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
 }
+
+// -----
+
+!A_mk = tensor<1023x255xf32>
+!B_nk = tensor<127x255xf32>
+!C_nm = tensor<127x1023xf32>
+
+#mkn_accesses = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#mkn_trait = {
+  indexing_maps = #mkn_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// Normalized dims are:                     ( k,  m,  n)(kk, mm, nn)
+// CHECK-DAG: #[[$km_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
+// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
+// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
+
+// CHECK-LABEL: @matmul_mk_nk_nm(
+func.func @matmul_mk_nk_nm(%A : !A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm {
+  //      CHECK: linalg.generic
+  // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]]
+  // CHECK-SAME:   ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} 
+  // CHECK-SAME:   ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<1x8x32x130xf32>)
+  // CHECK-SAME:  outs(%{{.*}} : tensor<1x128x8x130xf32>)
+  %0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) {
+    ^bb0(%a: f32, %b: f32, %c: f32):
+      %d = arith.mulf %a, %b : f32
+      %e = arith.addf %c, %d : f32
+      linalg.yield %e : f32
+  } -> !C_nm
+  return %0 : !C_nm
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
+  transform.structured.pack_greedily %generic
+      // In this spec, the "k" dimension is not packed but rather padded to the
+      // next multiple of 10 (i.e. 130).
+      matmul_packed_sizes = [8, 0, 32] 
+      matmul_padded_sizes_next_multiple_of = [0, 10, 0] 
+      matmul_inner_dims_order = [1, 2, 0]
+    : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
+}