[mlir] Transform op for multitile size generation
authorAlex Zinenko <zinenko@google.com>
Thu, 7 Jul 2022 13:56:06 +0000 (15:56 +0200)
committerAlex Zinenko <zinenko@google.com>
Tue, 12 Jul 2022 12:36:28 +0000 (12:36 +0000)
Introduce a structured transform op that emits IR computing the multi-tile
sizes with requested parameters (target size and divisor) for the given
structured op. The sizes may fold to arithmetic constant operations when the
shape is constant. These operations may then be used to call the existing
tiling transformation with a single non-zero dynamic size (i.e. perform
strip-mining) for each of the dimensions separately, thus achieving multi-size
tiling with optional loop interchange. A separate test exercises the entire
script.

Depends On D129217

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129287

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/test/Dialect/Linalg/multisize-tiling-full.mlir [new file with mode: 0644]
mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir [new file with mode: 0644]
mlir/test/python/dialects/transform_structured_ext.py
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 021158f..39ba998 100644 (file)
@@ -127,6 +127,71 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
   }];
 }
 
+def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Emits the IR computing the tile sizes `s1` and `s2` such that:
+
+      - there exists a combination of `n` tiles of size `s1` and `m` tiles of
+        size `s2` that covers the entirety of the iteration space `dimension` of
+        the target structured op;
+      - `s1`, `s2` is less than or equal to `target_size`;
+      - `s1` and `s2` are divisible by `divisor.
+
+    For example, for a dimension of size 54 with target size 12 and divisor 2,
+    this can emit the IR computing the tile size 10, used for 3 tiles, and 12,
+    used for 2 tiles, totally 10*3 + 12*2 = 54. Note that when the divisor does
+    not divide the original dimension size, it is impossible to compute such
+    tile sizes. An assertion is emitted to guard against this in the dynamic
+    case.
+
+    Expects the target size and the divisor to be strictly positive. Folds the
+    IR as much as possible, normally obtaining constant sizes and numbers of
+    tiles for a statically known dimension.
+
+    This does *not* consume the target handle and produces three handles each
+    pointing to single-result index-typed operations (which may be arithmetic
+    constant operations) defining the two respective tile sizes and the product
+    of the first tile size with the number of tiles of that size (useful for
+    splitting the iteration space).
+
+    This operation composes with the regular tiling when applied per-dimension:
+
+    ```mlir
+    %sz1, %sz2, %split = structured.multitile_sizes %target
+                         { target_size = 10, dimension = 1 }
+    %low, %high = structured.split %target after %split { dimension = 1 }
+    %tiled_low = structured.tile %low [0, %sz1]
+    %tiled_high = structured.tile %high [0, %sz2]
+    %common = merge_handles %tiled_low, %tiled_high
+
+    %sz3, %sz4, %split = structured.multitile_size %target
+                         { target_size = 42, dimension = 0 }
+    %sz3r, %sz4r, %splitr = replicate num(%common) %sz3, %sz4, %splitr
+    structured.split %common after %splitr { dimension = 0 }
+    // ...
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                       I64Attr:$dimension,
+                       I64Attr:$target_size,
+                       DefaultValuedAttr<I64Attr, "1">:$divisor);
+  let results = (outs PDL_Operation:$low_size,
+                      PDL_Operation:$high_size,
+                      PDL_Operation:$split_point);
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::linalg::LinalgOp target, 
+        ::llvm::SmallVector<::mlir::Operation *> &results,
+        TransformState &state);
+  }];
+}
+
+
 def PadOp : Op<Transform_Dialect, "structured.pad",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
index 7b45112..47cd647 100644 (file)
@@ -479,6 +479,48 @@ std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
 makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                     ValueRange allShapeSizes, ValueRange allTileSizes);
 
+/// A description of a multi-size tiling comprising tile sizes and numbers of
+/// tiles, expressed as Values which may or may not be constant. Multi-size
+/// currently means two-size.
+struct MultiSizeSpecification {
+  /// Tile sizes.
+  Value lowTileSize, highTileSize;
+  /// Number of tiles associated with each size.
+  Value lowTripCount, highTripCount;
+};
+
+/// Emits the IR computing the multi-sized tiling specification with two tile
+/// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such that
+/// there exist numbers of tiles with these sizes that fully cover the given
+/// iteration space `dimension` of the structured `op`.
+///
+/// The computation is as follows:
+///
+///   b = originalTripCount floordiv sizeDivisor
+///   t = (targetSize + sizeDivisor - 1) floordiv sizeDivisor
+///   d = (b + t - 1) floordiv t
+///   s = (b floordiv d) * sizeDivisor
+///   v = b % d
+///   u = d - v
+///
+/// where the tile sizes are `s` and `s` + `sizeDivisor`, and the numbers of
+/// the corresponding tiles are `u` and `v`, respectively.  Alternatively,
+///
+///   s * u + (s + sizeDivisor) * v == original size,
+///   where s mod sizeDivisor = 0.
+///
+/// Expects all values to be positive. In some cases with the target tile size
+/// sufficiently close to the dimension shape and non-unit divisor, it is
+/// impossible to compute such sizes. If `emitAssertion` is set, also emit the
+/// assertion that size computation succeeded.
+///
+/// Returns the specification consisting of both tile values and the number of
+/// tiles of each size.
+FailureOr<MultiSizeSpecification>
+computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
+                      OpFoldResult targetSize, OpFoldResult divisor,
+                      bool emitAssertions = true);
+
 /// All indices returned by IndexOp should be invariant with respect to tiling.
 /// Therefore, if an operation is tiled, we have to transform the indices
 /// accordingly, i.e. offset them by the values of the corresponding induction
index 819c42f..7d10433 100644 (file)
@@ -8,11 +8,14 @@ add_mlir_dialect_library(MLIRLinalgTransformOps
   MLIRLinalgTransformOpsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRAffineDialect
+  MLIRArithmeticDialect
   MLIRIR
   MLIRLinalgDialect
   MLIRLinalgTransforms
   MLIRParser
   MLIRPDLDialect
+  MLIRSCFDialect
   MLIRSideEffectInterfaces
   MLIRTransformDialect
   MLIRVectorDialect
index ab35b06..f1a9dcd 100644 (file)
@@ -8,12 +8,14 @@
 
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -277,6 +279,55 @@ LogicalResult transform::InterchangeOp::verify() {
 }
 
 //===---------------------------------------------------------------------===//
+// MultiTileSizesOp
+//===---------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
+    LinalgOp target, SmallVector<Operation *> &results, TransformState &state) {
+  OpBuilder builder(target.getContext());
+  builder.setInsertionPoint(target);
+  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
+  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
+  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
+      builder, target, getDimension(), targetSize, divisor);
+  if (failed(spec)) {
+    return emitSilenceableError() << "could not generate tile size computation";
+  }
+
+  Operation *splitPoint =
+      builder
+          .createOrFold<arith::MulIOp>(target.getLoc(), spec->lowTileSize,
+                                       spec->lowTripCount)
+          .getDefiningOp();
+  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
+  Operation *highTileSize = spec->highTileSize.getDefiningOp();
+  assert(lowTileSize && highTileSize && splitPoint &&
+         "tile sizes are not produced by operations");
+  results.reserve(results.size() + 3);
+  results.push_back(lowTileSize);
+  results.push_back(highTileSize);
+  results.push_back(splitPoint);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::MultiTileSizesOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
+                       transform::TransformMappingResource::get());
+  for (Value result : getResults()) {
+    effects.emplace_back(MemoryEffects::Allocate::get(), result,
+                         transform::TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(), result,
+                         transform::TransformMappingResource::get());
+  }
+
+  effects.emplace_back(MemoryEffects::Read::get(),
+                       transform::PayloadIRResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(),
+                       transform::PayloadIRResource::get());
+}
+
+//===---------------------------------------------------------------------===//
 // PadOp
 //===---------------------------------------------------------------------===//
 
@@ -782,6 +833,7 @@ class LinalgTransformDialectExtension
           LinalgTransformDialectExtension> {
 public:
   LinalgTransformDialectExtension() {
+    declareDependentDialect<AffineDialect>();
     declareDependentDialect<arith::ArithmeticDialect>();
     declareDependentDialect<pdl::PDLDialect>();
     declareDependentDialect<scf::SCFDialect>();
index a7524b7..d55876f 100644 (file)
@@ -13,6 +13,7 @@
 #include <utility>
 
 #include "PassDetail.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -82,6 +83,92 @@ void mlir::linalg::transformIndexOps(
   addTileLoopIvsToIndexOpResults(b, op, allIvs);
 }
 
+/// Asserts that the given index-typed value is strictly positive. If the value
+/// is an attribute, asserts at compile time, otherwise emits an assertion
+/// checked at runtime.
+static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
+                                         OpFoldResult value) {
+  if (auto attr = value.dyn_cast<Attribute>()) {
+    assert(attr.cast<IntegerAttr>().getValue().isStrictlyPositive() &&
+           "expected strictly positive tile size and divisor");
+    return;
+  }
+
+  Value zero = b.create<arith::ConstantIndexOp>(0);
+  Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
+                                            value.get<Value>(), zero);
+  b.create<cf::AssertOp>(
+      condition,
+      b.getStringAttr("expected strictly positive tile size and divisor"));
+}
+
+FailureOr<MultiSizeSpecification>
+mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
+                                    unsigned dimension, OpFoldResult targetSize,
+                                    OpFoldResult divisor, bool emitAssertions) {
+  // Bail out on dimension overflow.
+  if (dimension >= op.getNumLoops())
+    return failure();
+
+  // The code below works only on values.
+  ImplicitLocOpBuilder b(op.getLoc(), builder);
+  if (emitAssertions) {
+    emitIsPositiveIndexAssertion(b, targetSize);
+    emitIsPositiveIndexAssertion(b, divisor);
+  }
+  Value targetSizeValue = materializeOpFoldResult(b, targetSize);
+  Value divisorValue = materializeOpFoldResult(b, divisor);
+
+  // Find the trip count of the iteration space dimension for which the tile
+  // sizes are computed.
+  // TODO: update createFlatListOfOperandDims to return OpFoldResults and avoid
+  // littering by useless constant materialization.
+  SmallVector<Value, 4> allShapes =
+      op.createFlatListOfOperandDims(b, b.getLoc());
+  AffineMap shapesToLoops = op.getShapesToLoopsMap();
+  SmallVector<Value, 4> loopRanges =
+      applyMapToValues(b, op.getLoc(), shapesToLoops, allShapes);
+  Value tripCount = loopRanges[dimension];
+
+  // Compute the tile sizes and the respective numbers of tiles.
+  AffineExpr s0 = b.getAffineSymbolExpr(0);
+  AffineExpr s1 = b.getAffineSymbolExpr(1);
+  AffineExpr s2 = b.getAffineSymbolExpr(2);
+  auto apply = [&](AffineExpr expr, ValueRange values) -> Value {
+    return makeComposedAffineApply(b, b.getLoc(), expr, values);
+  };
+  Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
+  Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
+  Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
+  Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
+  Value v = apply(s0 % s1, {a, d});
+  Value u = apply(s0 - s1, {d, v});
+
+  MultiSizeSpecification spec;
+  spec.lowTileSize = s;
+  spec.highTileSize = apply(s0 + s1, {s, divisorValue});
+  spec.lowTripCount = u;
+  spec.highTripCount = v;
+
+  // If requested, emit the check that the tile sizes are computed correctly.
+  // For example, for iteration dimension size of 15 and the target size 8 it is
+  // impossible to find two tile sizes both divisible by 8 that fully cover the
+  // original space dimension.
+  if (emitAssertions) {
+    AffineExpr s3 = builder.getAffineSymbolExpr(3);
+    Value coveredSize =
+        apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
+                                  spec.highTileSize, spec.highTripCount});
+    Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                           coveredSize, tripCount);
+    b.create<cf::AssertOp>(
+        equals, builder.getStringAttr(
+                    "could not compute dynamic multi-size tile shapes"));
+  }
+
+  return spec;
+}
+
 // Insert a tile `source` into the destination tensor `dest`. The position at
 // which the tile is inserted (as well as size of tile) is taken from a given
 // ExtractSliceOp `sliceOp`.
index b6e078f..95bf2cc 100644 (file)
@@ -110,6 +110,29 @@ class InterchangeOp:
         ip=ip)
 
 
+class MultiTileSizesOp:
+  """Specialization for MultitileSizesOp class."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               dimension: Union[int, IntegerAttr],
+               target_size: Union[int, IntegerAttr],
+               divisor: Optional[Union[int, IntegerAttr]] = None,
+               loc=None,
+               ip=None):
+    super().__init__(
+        pdl.OperationType.get(),
+        pdl.OperationType.get(),
+        pdl.OperationType.get(),
+        _get_op_result_or_value(target),
+        dimension=_get_int64_attr(dimension),
+        target_size=_get_int64_attr(target_size),
+        divisor=_get_int64_attr(divisor if divisor else 1),
+        loc=loc,
+        ip=ip)
+
+
 class PadOp:
   """Specialization for PadOp class."""
 
diff --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
new file mode 100644 (file)
index 0000000..e30a140
--- /dev/null
@@ -0,0 +1,114 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter --canonicalize %s | FileCheck %s
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  // This implements a 2D multisize tiling with target sizes [3, 10].
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @linalg_generic in %arg1
+    %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3}
+    %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10}
+    %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 }
+    %3:2 = transform.structured.tile %2#0 [%1#0]
+    %4:2 = transform.structured.tile %2#1 [%1#1]
+    %5 = merge_handles %3#0, %4#0
+    %tt:3 = replicate num(%5) %t#0, %t#1, %t#2
+    %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 }
+    transform.structured.tile %6#0 [0, %tt#0]
+    transform.structured.tile %6#1 [0, %tt#1]
+  }
+}
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
+// CHECK-DAG: #[[$MAP_MIN_4_2:.+]] = affine_map<(d0) -> (-d0 + 4, 2)>
+// CHECK-DAG: #[[$MAP_MIN_16_8:.+]] = affine_map<(d0) -> (-d0 + 16, 8)>
+
+// CHECK-LABEL: @two_d
+// CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32>
+func.func @two_d(%arg0: tensor<10x34xf32>,
+                 %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+  ins(%arg0: tensor<10x34xf32>)
+  outs(%arg1: tensor<10x34xf32>) {
+  ^bb0(%0: f32, %1: f32):
+    %i = linalg.index 0 : index
+    %j = linalg.index 1 : index
+    %call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<10x34xf32>
+
+  // 2D multi-size tiling should produce for quadrants with sizes
+  //   (2, 8), (2, 9), (3, 8), (3, 9)
+  // respectively, and in this order.
+  // Check the full code for the first quadrant, the data flow for the second
+  // quadrant and only the overall code structure for the remaining quadrants.
+  //
+  // TODO: unfortunately, the canonicalization is insufficiently powerful to
+  // remove the affine min for sizes, leading to dynamic sizes even when tiling
+  // statically-shaped operation with constant tile sizes.
+
+  // CHECK:      %[[SLICE_1:.+]] = tensor.extract_slice %[[OUT]][0, 0] [4, 34] [1, 1]
+  // CHECK:      scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_1:.+]] = %[[SLICE_1]])
+  // CHECK:        %[[SZ1:.+]] = affine.min #[[$MAP_MIN_4_2]](%[[I1]])
+  // CHECK:        %[[INSLICE_1:.+]] = tensor.extract_slice %[[IN]][%[[I1]], 0] [%[[SZ1]], 34] [1, 1]
+  // CHECK:        %[[SZ2:.+]] = affine.min #[[$MAP_MIN_4_2]](%[[I1]])
+  // CHECK:        %[[OUTSLICE_1:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [%[[SZ2]], 34] [1, 1]
+
+  // CHECK:        %[[SLICE_2:.+]] = tensor.extract_slice %[[OUTSLICE_1]][0, 0] [%[[SZ1]], 16] [1, 1]
+  // CHECK:        %[[LOOPRES:.+]] = scf.for %[[I2:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_2:.+]] = %[[SLICE_2]])
+  // CHECK:          %[[SZ3:.+]] = affine.min #[[$MAP_MIN_16_8]](%[[I2]])
+  // CHECK:          %[[INSLICE_2:.+]] = tensor.extract_slice %[[INSLICE_1]][0, %[[I2]]] [%[[SZ1]], %[[SZ3]]] [1, 1]
+  // CHECK:          %[[SZ4:.+]] = tensor.dim %[[ITERARG_2]]
+  // CHECK:          %[[SZ5:.+]] = affine.min #[[$MAP_MIN_16_8]](%[[I2]])
+  // CHECK:          %[[OUTSLICE_2:.+]] = tensor.extract_slice %[[ITERARG_2]][0, %[[I2]]] [%[[SZ4]], %[[SZ5]]] [1, 1]
+
+  // CHECK:          %[[RESSLICE_1:.+]] = linalg.generic {{.*}} ins(%[[INSLICE_2]] : tensor<?x?xf32>) outs(%[[OUTSLICE_2]] : tensor<?x?xf32>)
+  // CHECK:          %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]]
+  // CHECK:          scf.yield %[[RESPARTIAL]]
+
+  // CHECK:        %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [%[[SZ1]], 16] [1, 1]
+  // CHECK:        %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [%[[SZ1]], 18] [1, 1]
+  // CHECK:        scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]])
+  // CHECK-COUNT-2:  tensor.extract_slice
+  // CHECK:          linalg.generic
+  // CHECK:          tensor.insert_slice
+  // CHECK:          scf.yield
+  // CHECK:        %[[INSERTED_2:.+]] = tensor.insert_slice %{{.*}} into %[[INSERTED]]
+  // CHECK:        %[[INSERTED_3:.+]] = tensor.insert_slice %[[INSERTED_2]] into %[[ITERARG_1]]
+  // CHECK:        scf.yield %[[INSERTED_3]]
+
+  // CHECK:        tensor.insert_slice
+  // CHECK:        tensor.extract_slice
+  // CHECK:        scf.for
+  // CHECK-COUNT-3:  tensor.extract_slice
+  // CHECK:          scf.for
+  // CHECK-COUNT-2:    tensor.extract_slice
+  // CHECK:            linalg.generic
+  // CHECK:            tensor.insert_slice
+  // CHECK:            scf.yield
+  // CHECK:          tensor.insert_slice
+  // CHECK:          tensor.extract_slice
+  // CHECK:          scf.for
+  // CHECK-COUNT-2:    tensor.extract_slice
+  // CHECK:            linalg.generic
+  // CHECK:            tensor.insert_slice
+  // CHECK:            scf.yield
+  // CHECK-COUNT-2:  tensor.insert_slice
+  // CHECK:          scf.yield
+  // CHECK:        %[[RESULT:.+]] = tensor.insert_slice
+  // CHECK:        return %[[RESULT]]
+
+  return %0 : tensor<10x34xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir b/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir
new file mode 100644 (file)
index 0000000..08fa934
--- /dev/null
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[$MAP13:.+]] = affine_map<() -> (13)>
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  sequence %arg0 {
+    ^bb0(%arg1: !pdl.operation):
+      %0 = pdl_match @pdl_target in %arg1
+      transform.structured.multitile_sizes %0 { target_size = 3, dimension = 0 }
+  }
+
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %0 with "transform.dialect"
+  }
+}
+
+// CHECK-LABEL: @multitile_sizes_static
+func.func @multitile_sizes_static(
+  %arg0: tensor<13x34xf32>, %arg1: tensor<34x42xf32>, %arg2: tensor<13x42xf32>)
+    -> tensor<13x42xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<13x34xf32>, tensor<34x42xf32>)
+                     outs(%arg2: tensor<13x42xf32>)
+    -> tensor<13x42xf32>
+  // The first application computes the total size.
+  // CHECK: %{{.*}} = affine.apply #[[$MAP13]]()
+  // CHECK: %[[SIZE:.+]] = affine.apply #[[$MAP13]]()
+  // CHECK: %[[COND:.+]] = arith.cmpi eq, %[[SIZE]], %{{.*}}
+  // CHECK: cf.assert %[[COND]], "could not compute dynamic multi-size tile shapes"
+
+  return %0 : tensor<13x42xf32>
+}
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  sequence %arg0 {
+    ^bb0(%arg1: !pdl.operation):
+      %0 = pdl_match @pdl_target in %arg1
+      transform.structured.multitile_sizes %0 { target_size = 3, divisor = 2, dimension = 0 }
+  }
+
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %0 with "transform.dialect"
+  }
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<()[s0] -> ([[A_IMPL:s0 floordiv 2]])>
+// CHECK: #[[$MAP_T:.+]] = affine_map<() -> (2)>
+// CHECK: #[[$MAP_D:.+]] = affine_map<()[s0] -> ([[D_IMPL:\(s0 floordiv 2 \+ 1\) floordiv 2]])>
+// CHECK: #[[$MAP_S:.+]] = affine_map<()[s0] -> ((([[A_IMPL]]) floordiv ([[D_IMPL]])) * 2)>
+// CHECK: #[[$MAP_V:.+]] = affine_map<()[s0] -> (([[A_IMPL]]) mod ([[D_IMPL]]))>
+// CHECK: #[[$MAP_U:.+]] = affine_map<()[s0] -> ([[D_IMPL]] - ([[A_IMPL]]) mod ([[D_IMPL]]))>
+
+// CHECK-LABEL: @multitile_sizes_dynamic
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>, %{{.*}}: tensor<?x?xf32>, %{{.*}}: tensor<?x?xf32>)
+func.func @multitile_sizes_dynamic(
+  // For matmul, the extent of the first iteration space dimension is equal to
+  // the size of the first dimension of the first tensor. The indexing map was
+  // folded so there is no map application happening.
+  //
+  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+  //
+  // The following are the maps as emitted by computeMultiTileSizes.
+  // CHECK: affine.apply #[[$MAP_A]]()[%[[DIM]]]
+  // CHECK: affine.apply #[[$MAP_T]]()
+  // CHECK: affine.apply #[[$MAP_D]]()[%[[DIM]]]
+  // CHECK: affine.apply #[[$MAP_S]]()[%[[DIM]]]
+  // CHECK: affine.apply #[[$MAP_V]]()[%[[DIM]]]
+  // CHECK: affine.apply #[[$MAP_U]]()[%[[DIM]]]
+  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+  
+  return %0 : tensor<?x?xf32>
+}
index cd4412f..9d2641c 100644 (file)
@@ -55,6 +55,20 @@ def testInterchange():
 
 
 @run
+def testMultitileSizes():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.MultiTileSizesOp(
+        sequence.bodyTarget, dimension=1, target_size=42)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testMultitileSizes
+  # CHECK: transform.sequence
+  # CHECK: transform.structured.multitile_sizes
+  # CHECK-DAG: dimension = 1
+  # CHECK-DAG: target_size = 42
+
+
+@run
 def testPad():
   sequence = transform.SequenceOp()
   with InsertionPoint(sequence.body):
index e86b41c..9cb2147 100644 (file)
@@ -7461,7 +7461,9 @@ cc_library(
     ],
     includes = ["include"],
     deps = [
+        ":AffineDialect",
         ":ArithmeticDialect",
+        ":ControlFlowDialect",
         ":IR",
         ":LinalgDialect",
         ":LinalgTransformOpsIncGen",