From 26e916334ebc3cb34f1c020e00c731bd60b0323a Mon Sep 17 00:00:00 2001
From: Tobias Gysi
Date: Fri, 7 May 2021 05:59:05 +0000
Subject: [PATCH] [mlir][linalg] Add IndexedGenericOp to GenericOp
 canonicalization.

Replace all `linalg.indexed_generic` ops by `linalg.generic` ops that
access the iteration indices using the `linalg.index` op.

Differential Revision: https://reviews.llvm.org/D101612
---
 .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td  |  1 -
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp           | 43 +++++++++++++++++++++-
 mlir/test/Dialect/Linalg/bufferize.mlir            |  4 +-
 .../Linalg/canonicalize-duplicate-inputs.mlir      |  4 +-
 mlir/test/Dialect/Linalg/canonicalize.mlir         | 33 +++++++++++++++++
 .../Dialect/Linalg/fusion-indexed-generic.mlir     | 12 +++---
 mlir/test/Dialect/Linalg/fusion-tensor.mlir        | 34 +++++++----------
 mlir/test/Dialect/Linalg/reshape_fusion.mlir       | 37 +++++++++++--------
 mlir/test/Dialect/Linalg/tile-indexed-generic.mlir | 30 +++++++++------
 mlir/test/Dialect/Linalg/tile-tensors.mlir         |  2 +-
 10 files changed, 142 insertions(+), 58 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 8e323bd..79b7a7d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -811,7 +811,6 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
   ];
 
   let verifier = [{ return ::verify(*this); }];
 
-  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a43c4ac..01c240f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -735,6 +735,48 @@ static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
 
 static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
 
+namespace {
+
+/// Replace indexed_generic ops by generic ops that access the iteration indices
+/// using linalg.index ops instead of leading index block arguments.
+struct ConvertIndexedToGenericOp : OpRewritePattern<IndexedGenericOp> {
+  using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(IndexedGenericOp indexedOp,
+                                PatternRewriter &rewriter) const override {
+    // Map every index block argument to a linalg.index op in the body.
+    BlockAndValueMapping bvm;
+    if (Block *body = indexedOp.getBody()) {
+      rewriter.setInsertionPointToStart(body);
+      for (const auto &en : llvm::enumerate(
+               body->getArguments().take_front(indexedOp.getNumLoops()))) {
+        Value index = rewriter.create<IndexOp>(indexedOp.getLoc(), en.index());
+        bvm.map(en.value(), index);
+      }
+    }
+
+    // Create the generic op; cloning drops the mapped index block arguments.
+    rewriter.setInsertionPointAfter(indexedOp);
+    SmallVector<StringRef, 4> iterators = llvm::to_vector<4>(
+        indexedOp.iterator_types().getAsValueRange<StringAttr>());
+    GenericOp genericOp = rewriter.create<GenericOp>(
+        indexedOp.getLoc(), indexedOp->getResultTypes(), indexedOp.getInputs(),
+        indexedOp.getOutputs(), indexedOp.getIndexingMaps(), iterators);
+    Region &genericRegion = genericOp.region();
+    Region &indexedRegion = indexedOp.region();
+    rewriter.cloneRegionBefore(indexedRegion, genericRegion,
+                               genericRegion.begin(), bvm);
+
+    rewriter.replaceOp(indexedOp, genericOp->getResults());
+    return success();
+  }
+};
+} // namespace
+
+void IndexedGenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                   MLIRContext *context) {
+  results.add<ConvertIndexedToGenericOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // InitTensorOp
 //===----------------------------------------------------------------------===//
@@ -3161,7 +3203,6 @@ CANONICALIZERS_AND_FOLDERS(PoolingSumOp)
 CANONICALIZERS_AND_FOLDERS(CopyOp)
 CANONICALIZERS_AND_FOLDERS(FillOp)
 CANONICALIZERS_AND_FOLDERS(GenericOp)
-CANONICALIZERS_AND_FOLDERS(IndexedGenericOp)
 
 // All named ops canonicalizers and folders are auto-generated in the
 // .cpp.inc.
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir index 757b7a1..198936a 100644 --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -96,10 +96,10 @@ func @multiple_results(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { // CHECK-LABEL: func @multiple_results_indexed // CHECK: %[[RESULT0:.*]] = memref.alloc() : memref<4xi32> // CHECK: %[[RESULT1:.*]] = memref.alloc() : memref<4xi32> -// CHECK: linalg.indexed_generic +// CHECK: linalg.generic // CHECK-SAME: ins(%{{.*}} : memref<4xi32>) // CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref<4xi32>, memref<4xi32>) -// CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32): +// CHECK-NEXT: ^bb0(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32): func @multiple_results_indexed(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) { %0, %1 = linalg.indexed_generic { diff --git a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir index de894b9..e314bc4 100644 --- a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir @@ -100,8 +100,8 @@ func @multiple_different_redundant_args(%arg0: tensor, %arg1: tensor (d0)> // CHECK-LABEL: @indexed_generic func @indexed_generic(%arg0: tensor) -> tensor { - // CHECK: linalg.indexed_generic - // CHECK: ^bb0(%{{.*}}: index, %[[BBARG:.*]]: f32, %{{[a-zA-Z0-9]+}}: f32): + // CHECK: linalg.generic + // CHECK: ^bb0(%[[BBARG:.*]]: f32, %{{[a-zA-Z0-9]+}}: f32): // CHECK: addf %[[BBARG]], %[[BBARG]] %0 = linalg.indexed_generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg0 : tensor, tensor) diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 39fc2b9..536e361 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ 
b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1088,3 +1088,36 @@ func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index, // CHECK: %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]] // CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]] // CHECK: return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]] + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> + +func @indexed_generic(%arg0: memref, %arg1: memref) { + linalg.indexed_generic { + indexing_maps = [#map, #map], + iterator_types = ["parallel", "parallel"]} + ins(%arg0 : memref) + outs(%arg1 : memref) { + ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index): + %0 = addi %arg4, %arg5 : index + %1 = addi %0, %arg6 : index + %2 = addi %1, %arg7 : index + linalg.yield %2 : index + } + return +} + +// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: func @indexed_generic +// CHECK-NEXT: linalg.generic { +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel", "parallel"]} +// CHECK-SAME: ins(%[[ARG0:[A-Za-z0-9_]+]] : memref) +// CHECK-SAME: outs(%[[ARG1:[A-Za-z0-9_]+]] : memref) +// CHECK: ^bb0(%[[ARG2:[A-Za-z0-9_]+]]: index, %[[ARG3:[A-Za-z0-9_]+]]: index): +// CHECK-NEXT: %[[IDX0:.+]] = linalg.index 0 : index +// CHECK-NEXT: %[[IDX1:.+]] = linalg.index 1 : index +// CHECK-NEXT: %[[SUM0:.+]] = addi %[[IDX0]], %[[IDX1]] : index +// CHECK-NEXT: %[[SUM1:.+]] = addi %[[SUM0]], %[[ARG2]] : index +// CHECK-NEXT: %[[SUM2:.+]] = addi %[[SUM1]], %[[ARG3]] : index +// CHECK-NEXT: linalg.yield %[[SUM2]] : index diff --git a/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir b/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir index 149ba74..38a4369 100644 --- a/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir +++ b/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir @@ -57,7 +57,7 @@ func @fuse_indexed_generic_consumer(%A: memref, // CHECK: linalg.generic // CHECK-NOT: addi // CHECK: addf -// CHECK: linalg.indexed_generic +// 
CHECK: linalg.generic // CHECK: index_cast // ----- @@ -177,8 +177,9 @@ func @fuse_indexed_generic_producer(%A: memref, // CHECK-LABEL: func @fuse_indexed_generic_producer // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = // CHECK-NOT: scf.parallel -// CHECK: linalg.indexed_generic -// CHECK: ^bb0([[i:%.*]]: index, [[j:%.*]]: index +// CHECK: linalg.generic +// CHECK: [[i:%.*]] = linalg.index 0 : index +// CHECK: [[j:%.*]] = linalg.index 1 : index // CHECK: [[i_new:%.*]] = addi [[i]], [[I]] : index // CHECK: [[j_new:%.*]] = addi [[j]], [[J]] : index // CHECK: {{.*}} = index_cast [[i_new]] : index to i32 @@ -295,8 +296,9 @@ func @fuse_indexed_generic_producer_tile_second_dim_only(%A: memref, // CHECK: [[C0:%.*]] = constant 0 : index // CHECK: scf.parallel ([[J:%.*]]) = // CHECK-NOT: scf.parallel -// CHECK: linalg.indexed_generic -// CHECK: ^bb0([[i:%.*]]: index, [[j:%.*]]: index +// CHECK: linalg.generic +// CHECK: [[i:%.*]] = linalg.index 0 : index +// CHECK: [[j:%.*]] = linalg.index 1 : index // CHECK: [[j_new:%.*]] = addi [[j]], [[J]] : index // CHECK: {{.*}} = index_cast [[i]] : index to i32 // CHECK: {{.*}} = index_cast [[j_new]] : index to i32 diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir index 1ba2d37..7b43c0f 100644 --- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir @@ -231,11 +231,8 @@ func @indexed_generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @indexed_generic_op_constant_fusion // CHECK: %[[CST:.*]] = constant {{.*}} : f32 -// CHECK: linalg.indexed_generic +// CHECK: linalg.generic // CHECK: ^{{[a-zA-Z0-9_]*}} -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: index -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: index -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]*]]: index // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32) // CHECK: mulf %[[CST]], %[[ARG4]] @@ -299,11 +296,8 
@@ func @indexed_generic_op_zero_dim_constant_fusion // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @indexed_generic_op_zero_dim_constant_fusion // CHECK: %[[CST:.*]] = constant {{.*}} : f32 -// CHECK: linalg.indexed_generic +// CHECK: linalg.generic // CHECK: ^{{[a-zA-Z0-9_]*}} -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: index -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: index -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]*]]: index // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32) // CHECK: mulf %[[CST]], %[[ARG4]] @@ -342,14 +336,14 @@ func @generic_op_indexed_generic_op_fusion(%arg0: tensor, } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @generic_op_indexed_generic_op_fusion -// CHECK-NOT: linalg.generic -// CHECK: linalg.indexed_generic +// CHECK-NOT: linalg.indexed_generic +// CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]] // CHECK: ^{{[a-zA-Z0-9_]*}} -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32 // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: i32 +// CHECK: %[[ARG0:.+]] = linalg.index 0 : index +// CHECK: %[[ARG1:.+]] = linalg.index 1 : index // CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ARG3]] : i32 // CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32 // CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32 @@ -444,13 +438,13 @@ func @indexed_generic_op_generic_op_fusion(%arg0: tensor, } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @indexed_generic_op_generic_op_fusion -// CHECK: linalg.indexed_generic +// CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]] // CHECK: ^{{[a-zA-Z0-9_]*}} -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32 // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: 
i32 +// CHECK: %[[ARG0:.+]] = linalg.index 0 : index +// CHECK: %[[ARG1:.+]] = linalg.index 1 : index // CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32 // CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32 // CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND]] : i32 @@ -549,12 +543,12 @@ func @indexed_generic_op_fusion(%arg0: tensor) -> tensor { } // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @indexed_generic_op_fusion -// CHECK: linalg.indexed_generic +// CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]] // CHECK: ^{{[a-zA-Z0-9_]*}} -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32 +// CHECK: %[[ARG0:.+]] = linalg.index 0 : index +// CHECK: %[[ARG1:.+]] = linalg.index 1 : index // CHECK: %[[ADD_OPERAND1:.+]] = index_cast %[[ARG1]] : index to i32 // CHECK: %[[SUB_OPERAND1:.+]] = index_cast %[[ARG0]] : index to i32 // CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND1]] : i32 @@ -564,7 +558,7 @@ func @indexed_generic_op_fusion(%arg0: tensor) -> tensor { // CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ADD_OPERAND2]] : i32 // CHECK: %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32 // CHECK: linalg.yield %[[VAL4]] : i32 -// CHECK-NOT: linalg.indexed_generic +// CHECK-NOT: linalg.generic // ----- @@ -663,7 +657,7 @@ func @scalar_indexed_generic_fusion // CHECK: func @scalar_indexed_generic_fusion // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<5x1x1xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor -// CHECK: %[[T0:.+]] = linalg.indexed_generic +// CHECK: %[[T0:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: iterator_types = ["parallel"] // CHECK-SAME: ins(%[[ARG1]] : tensor) diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir index 9855dc9..ff9aeeb 100644 --- 
a/mlir/test/Dialect/Linalg/reshape_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -194,12 +194,14 @@ func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor, // checking the op body here. // CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)> // CHECK: func @indexed_generic_op_reshape_producer_fusion -// CHECK: linalg.indexed_generic +// CHECK: linalg.generic // CHECK: ^{{.*}}( -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index, -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32, // CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32) +// CHECK: %[[ARG2:.+]] = linalg.index 0 : index +// CHECK: %[[ARG3:.+]] = linalg.index 1 : index +// CHECK: %[[ARG4:.+]] = linalg.index 2 : index +// CHECK: %[[ARG5:.+]] = linalg.index 3 : index // CHECK: %[[T3:.+]] = affine.apply #[[MAP]](%[[ARG3]], %[[ARG2]]) // CHECK: %[[T4:.+]] = muli %[[ARG6]], %[[ARG7]] // CHECK: %[[T5:.+]] = index_cast %[[T3]] @@ -290,12 +292,14 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor, // checking the op body here. 
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 5 + d2 * 20)> // CHECK: func @indexed_generic_op_reshape_consumer_fusion -// CHECK: linalg.indexed_generic +// CHECK: linalg.generic // CHECK: ^{{.*}}( -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index, -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32, // CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32) +// CHECK: %[[ARG2:.+]] = linalg.index 0 : index +// CHECK: %[[ARG3:.+]] = linalg.index 1 : index +// CHECK: %[[ARG4:.+]] = linalg.index 2 : index +// CHECK: %[[ARG5:.+]] = linalg.index 3 : index // CHECK: %[[T3:.+]] = affine.apply #[[MAP]](%[[ARG5]], %[[ARG4]], %[[ARG3]]) // CHECK: %[[T4:.+]] = muli %[[ARG6]], %[[ARG7]] // CHECK: %[[T5:.+]] = index_cast %[[ARG2]] @@ -390,16 +394,19 @@ func @reshape_as_consumer_permutation // CHECK-DAG: %[[T2:.+]] = linalg.tensor_reshape %[[ARG1]] // CHECK-SAME: [0, 1, 2], [3] // CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7] -// CHECK: %[[T4:.+]] = linalg.indexed_generic +// CHECK: %[[T4:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]] // CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>) // CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xi32>) // CHECK: ^{{.+}}( -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index, -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index, -// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index, %[[ARG7:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32, // CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: i32) +// CHECK: %[[ARG2:.+]] = linalg.index 0 : index +// CHECK: %[[ARG3:.+]] = linalg.index 1 : index +// CHECK: %[[ARG4:.+]] = linalg.index 2 : index +// CHECK: %[[ARG5:.+]] = linalg.index 3 : index +// CHECK: %[[ARG6:.+]] = linalg.index 4 : index +// CHECK: %[[ARG7:.+]] = 
linalg.index 5 : index // CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP8]](%[[ARG3]], %[[ARG2]]) // CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP9]](%[[ARG6]], %[[ARG5]], %[[ARG4]]) // CHECK-DAG: %[[T7:.+]] = addi %[[ARG8]], %[[ARG9]] @@ -508,16 +515,16 @@ func @reshape_as_producer_projected_permutation( // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 8)> // CHECK: @reshape_as_producer_projected_permutation // CHECK-SAME: %[[ARG0:.+]]: tensor<33x8x?xi32> -// CHECK: %[[RES:.+]] = linalg.indexed_generic +// CHECK: %[[RES:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: ins(%[[ARG0]] : tensor<33x8x?xi32>) // CHECK: ^{{.+}}( -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index, -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32, // CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: i32) +// CHECK: %[[ARG1:.+]] = linalg.index 0 : index +// CHECK: %[[ARG2:.+]] = linalg.index 1 : index +// CHECK: %[[ARG3:.+]] = linalg.index 2 : index +// CHECK: %[[ARG4:.+]] = linalg.index 3 : index // CHECK: %[[T0:.+]] = affine.apply #[[MAP2]](%[[ARG2]], %[[ARG1]]) // CHECK: %[[T1:.+]] = index_cast %[[T0]] : index to i32 // CHECK: %[[T2:.+]] = addi %[[ARG5]], %[[T1]] : i32 diff --git a/mlir/test/Dialect/Linalg/tile-indexed-generic.mlir b/mlir/test/Dialect/Linalg/tile-indexed-generic.mlir index 660236a..0c0e7b0 100644 --- a/mlir/test/Dialect/Linalg/tile-indexed-generic.mlir +++ b/mlir/test/Dialect/Linalg/tile-indexed-generic.mlir @@ -24,8 +24,9 @@ func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) // TILE-10n25-LABEL: func @indexed_generic_vector // TILE-10n25: %[[C10:.*]] = constant 10 : index // TILE-10n25: scf.for %[[J:.*]] = {{.*}} step %[[C10]] -// TILE-10n25: linalg.indexed_generic -// TILE-10n25: ^bb0(%[[I:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32) +// TILE-10n25: linalg.generic +// 
TILE-10n25: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32) +// TILE-10n25: %[[I:.*]] = linalg.index 0 : index // TILE-10n25: %[[NEW_I:.*]] = addi %[[I]], %[[J]] : index // TILE-10n25: %[[NEW_I_INT:.*]] = index_cast %[[NEW_I]] : index to i32 // TILE-10n25: %[[NEW_I_FLOAT:.*]] = sitofp %[[NEW_I_INT]] : i32 to f32 @@ -34,8 +35,9 @@ func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) // TILE-25n0-LABEL: func @indexed_generic_vector // TILE-25n0: %[[C25:.*]] = constant 25 : index // TILE-25n0: scf.for %[[J:.*]] = {{.*}} step %[[C25]] -// TILE-25n0: linalg.indexed_generic -// TILE-25n0: ^bb0(%[[I:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32) +// TILE-25n0: linalg.generic +// TILE-25n0: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32) +// TILE-25n0: %[[I:.*]] = linalg.index 0 : index // TILE-25n0: %[[NEW_I:.*]] = addi %[[I]], %[[J]] : index // TILE-25n0: %[[NEW_I_INT:.*]] = index_cast %[[NEW_I]] : index to i32 // TILE-25n0: %[[NEW_I_FLOAT:.*]] = sitofp %[[NEW_I_INT]] : i32 to f32 @@ -43,7 +45,7 @@ func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) // TILE-0n25-LABEL: func @indexed_generic_vector // TILE-0n25-NOT: scf.for %[[J:.*]] = {{.*}} step % -// TILE-0n25: linalg.indexed_generic +// TILE-0n25: linalg.generic #combined_indices_trait = { args_in = 1, @@ -73,8 +75,10 @@ func @indexed_generic_matrix(%operand: memref<50x99xf32>, %result: memref<50x50x // TILE-10n25-DAG: %[[C10:.*]] = constant 10 : index // TILE-10n25: scf.for %[[K:.*]] = {{.*}} step %[[C10]] // TILE-10n25: scf.for %[[L:.*]] = {{.*}} step %[[C25]] -// TILE-10n25: linalg.indexed_generic -// TILE-10n25: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// TILE-10n25: linalg.generic +// TILE-10n25: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// TILE-10n25: %[[I:.*]] = linalg.index 0 : index +// TILE-10n25: %[[J:.*]] = linalg.index 1 : index // TILE-10n25: %[[NEW_I:.*]] = addi %[[I]], %[[K]] : index // TILE-10n25: %[[NEW_J:.*]] = addi %[[J]], 
%[[L]] : index // TILE-10n25: %[[NEW_INT_I:.*]] = index_cast %[[NEW_I]] : index to i32 @@ -86,8 +90,10 @@ func @indexed_generic_matrix(%operand: memref<50x99xf32>, %result: memref<50x50x // TILE-25n0-LABEL: func @indexed_generic_matrix // TILE-25n0: %[[C25:.*]] = constant 25 : index // TILE-25n0: scf.for %[[L:.*]] = {{.*}} step %[[C25]] -// TILE-25n0: linalg.indexed_generic -// TILE-25n0: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// TILE-25n0: linalg.generic +// TILE-25n0: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// TILE-25n0: %[[I:.*]] = linalg.index 0 : index +// TILE-25n0: %[[J:.*]] = linalg.index 1 : index // TILE-25n0: %[[NEW_I:.*]] = addi %[[I]], %[[L]] : index // TILE-25n0: %[[NEW_INT_I:.*]] = index_cast %[[NEW_I]] : index to i32 // TILE-25n0: %[[NEW_FLOAT_I:.*]] = sitofp %[[NEW_INT_I]] : i32 to f32 @@ -98,8 +104,10 @@ func @indexed_generic_matrix(%operand: memref<50x99xf32>, %result: memref<50x50x // TILE-0n25-LABEL: func @indexed_generic_matrix // TILE-0n25: %[[C25:.*]] = constant 25 : index // TILE-0n25: scf.for %[[L:.*]] = {{.*}} step %[[C25]] -// TILE-0n25: linalg.indexed_generic -// TILE-0n25: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// TILE-0n25: linalg.generic +// TILE-0n25: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// TILE-0n25: %[[I:.*]] = linalg.index 0 : index +// TILE-0n25: %[[J:.*]] = linalg.index 1 : index // TILE-0n25: %[[NEW_J:.*]] = addi %[[J]], %[[L]] : index // TILE-0n25: %[[INT_I:.*]] = index_cast %[[I]] : index to i32 // TILE-0n25: %[[FLOAT_I:.*]] = sitofp %[[INT_I]] : i32 to f32 diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir index 744cc59..88a9182 100644 --- a/mlir/test/Dialect/Linalg/tile-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir @@ -164,7 +164,7 @@ func @indexed_generic_op_tensors( // CHECK: %[[STARG0:.+]] = subtensor %[[ARG0]][{{.+}}] : tensor to tensor // CHECK: %[[STARG1:.+]] = 
subtensor %[[ARG1]][{{.+}}] : tensor to tensor // CHECK: %[[STARG2:.+]] = subtensor %[[TC2]][{{.+}}] : tensor to tensor -// CHECK: %[[STRETURN:.+]] = linalg.indexed_generic +// CHECK: %[[STRETURN:.+]] = linalg.generic // CHECK-SAME: ins(%[[STARG0]], %[[STARG1]] : tensor, tensor) // CHECK-SAME: outs(%[[STARG2]] : tensor) // CHECK: %[[TD:.+]] = subtensor_insert %[[STRETURN]] into %[[TC2]] -- 2.7.4